Added HNSW index type - #181

This commit is contained in:
Andrew Kane
2023-08-08 16:42:47 -07:00
parent 19a6c81367
commit 51d292c93d
29 changed files with 3927 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
## 0.5.0 (unreleased)
- Added HNSW index type
- Added support for parallel index builds
- Added `l1_distance` function
- Added element-wise multiplication for vectors

View File

@@ -3,7 +3,7 @@ EXTVERSION = 0.4.4
MODULE_big = vector
DATA = $(wildcard sql/*--*.sql)
OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
TESTS = $(wildcard test/sql/*.sql)
REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS))

View File

@@ -1,7 +1,7 @@
EXTENSION = vector
EXTVERSION = 0.4.4
OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
REGRESS_OPTS = --inputdir=test --load-extension=vector

View File

@@ -18,3 +18,26 @@ CREATE AGGREGATE sum(vector) (
COMBINEFUNC = vector_add,
PARALLEL = SAFE
);
CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;
COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);
CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);

View File

@@ -227,7 +227,7 @@ CREATE OPERATOR > (
RESTRICT = scalargtsel, JOIN = scalargtjoinsel
);
-- access method
-- access methods
CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
@@ -236,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler;
COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method';
CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;
COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
-- opclasses
CREATE OPERATOR CLASS vector_ops
@@ -267,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops
FUNCTION 2 vector_norm(vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);
CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);
CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);

224
src/hnsw.c Normal file
View File

@@ -0,0 +1,224 @@
#include "postgres.h"
#include <float.h>
#include <math.h>
#include "access/amapi.h"
#include "commands/vacuum.h"
#include "hnsw.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"
#if PG_VERSION_NUM >= 120000
#include "commands/progress.h"
#endif
int hnsw_ef_search;
static relopt_kind hnsw_relopt_kind;
/*
* Initialize index options and variables
*/
void
HnswInit(void)
{
hnsw_relopt_kind = add_reloption_kind();
add_int_reloption(hnsw_relopt_kind, "m", "Max number of connections",
HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction",
HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search",
"Valid range is 10..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
}
/*
* Get the name of index build phase
*/
#if PG_VERSION_NUM >= 120000
static char *
hnswbuildphasename(int64 phasenum)
{
switch (phasenum)
{
case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE:
return "initializing";
case PROGRESS_HNSW_PHASE_LOAD:
return "loading tuples";
default:
return NULL;
}
}
#endif
/*
* Estimate the cost of an index scan
*/
static void
hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
Cost *indexStartupCost, Cost *indexTotalCost,
Selectivity *indexSelectivity, double *indexCorrelation,
double *indexPages)
{
GenericCosts costs;
int m;
int entryLevel;
Relation index;
#if PG_VERSION_NUM < 120000
List *qinfos;
#endif
/* Never use index without order */
if (path->indexorderbys == NULL)
{
*indexStartupCost = DBL_MAX;
*indexTotalCost = DBL_MAX;
*indexSelectivity = 0;
*indexCorrelation = 0;
*indexPages = 0;
return;
}
MemSet(&costs, 0, sizeof(costs));
index = index_open(path->indexinfo->indexoid, NoLock);
m = HnswGetM(index);
index_close(index, NoLock);
/* Approximate entry level */
entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m);
/* TODO Improve estimate of visited tuples (currently underestimates) */
/* Account for number of tuples (or entry level), m, and ef_search */
costs.numIndexTuples = (entryLevel + 2) * m;
#if PG_VERSION_NUM >= 120000
genericcostestimate(root, path, loop_count, &costs);
#else
qinfos = deconstruct_indexquals(path);
genericcostestimate(root, path, loop_count, qinfos, &costs);
#endif
/* Use total cost since most work happens before first tuple is returned */
*indexStartupCost = costs.indexTotalCost;
*indexTotalCost = costs.indexTotalCost;
*indexSelectivity = costs.indexSelectivity;
*indexCorrelation = costs.indexCorrelation;
*indexPages = costs.numIndexPages;
}
/*
* Parse and validate the reloptions
*/
static bytea *
hnswoptions(Datum reloptions, bool validate)
{
static const relopt_parse_elt tab[] = {
{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
};
#if PG_VERSION_NUM >= 130000
return (bytea *) build_reloptions(reloptions, validate,
hnsw_relopt_kind,
sizeof(HnswOptions),
tab, lengthof(tab));
#else
relopt_value *options;
int numoptions;
HnswOptions *rdopts;
options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions);
rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions);
fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions,
validate, tab, lengthof(tab));
return (bytea *) rdopts;
#endif
}
/*
* Validate catalog entries for the specified operator class
*/
static bool
hnswvalidate(Oid opclassoid)
{
return true;
}
/*
* Define index handler
*
* See https://www.postgresql.org/docs/current/index-api.html
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler);
Datum
hnswhandler(PG_FUNCTION_ARGS)
{
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 2;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false; /* can change direction mid-scan */
amroutine->amcanunique = false;
amroutine->amcanmulticol = false;
amroutine->amoptionalkey = true;
amroutine->amsearcharray = false;
amroutine->amsearchnulls = false;
amroutine->amstorage = false;
amroutine->amclusterable = false;
amroutine->ampredlocks = false;
amroutine->amcanparallel = false;
amroutine->amcaninclude = false;
#if PG_VERSION_NUM >= 130000
amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */
amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL;
#endif
amroutine->amkeytype = InvalidOid;
/* Interface functions */
amroutine->ambuild = hnswbuild;
amroutine->ambuildempty = hnswbuildempty;
amroutine->aminsert = hnswinsert;
amroutine->ambulkdelete = hnswbulkdelete;
amroutine->amvacuumcleanup = hnswvacuumcleanup;
amroutine->amcanreturn = NULL; /* tuple not included in heapsort */
amroutine->amcostestimate = hnswcostestimate;
amroutine->amoptions = hnswoptions;
amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
#if PG_VERSION_NUM >= 120000
amroutine->ambuildphasename = hnswbuildphasename;
#endif
amroutine->amvalidate = hnswvalidate;
#if PG_VERSION_NUM >= 140000
amroutine->amadjustmembers = NULL;
#endif
amroutine->ambeginscan = hnswbeginscan;
amroutine->amrescan = hnswrescan;
amroutine->amgettuple = hnswgettuple;
amroutine->amgetbitmap = NULL;
amroutine->amendscan = hnswendscan;
amroutine->ammarkpos = NULL;
amroutine->amrestrpos = NULL;
/* Interface functions to support parallel index scans */
amroutine->amestimateparallelscan = NULL;
amroutine->aminitparallelscan = NULL;
amroutine->amparallelrescan = NULL;
PG_RETURN_POINTER(amroutine);
}

301
src/hnsw.h Normal file
View File

@@ -0,0 +1,301 @@
#ifndef HNSW_H
#define HNSW_H
#include "postgres.h"
#include "access/generic_xlog.h"
#include "access/reloptions.h"
#include "nodes/execnodes.h"
#include "port.h" /* for random() */
#include "utils/sampling.h"
#include "vector.h"
#if PG_VERSION_NUM < 110000
#error "Requires PostgreSQL 11+"
#endif
#define HNSW_MAX_DIM 2000
/* Support functions */
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
#define HNSW_PAGE_ID 0xFF85
/* Preserved page numbers */
#define HNSW_METAPAGE_BLKNO 0
#define HNSW_HEAD_BLKNO 1 /* first element page */
#define HNSW_DEFAULT_M 16
#define HNSW_MIN_M 4
#define HNSW_MAX_M 100
#define HNSW_DEFAULT_EF_CONSTRUCTION 40
#define HNSW_MIN_EF_CONSTRUCTION 10
#define HNSW_MAX_EF_CONSTRUCTION 1000
#define HNSW_DEFAULT_EF_SEARCH 40
#define HNSW_MIN_EF_SEARCH 10
#define HNSW_MAX_EF_SEARCH 1000
#define HNSW_ELEMENT_TUPLE_TYPE 1
#define HNSW_NEIGHBOR_TUPLE_TYPE 2
/* Make graph robust against non-HOT updates */
#define HNSW_HEAPTIDS 10
/* Build phases */
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
#define PROGRESS_HNSW_PHASE_LOAD 2
#define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim))
#define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, neighbors) + ((level) + 2) * (m) * sizeof(HnswNeighborTupleItem))
#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page))
#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page))
#if PG_VERSION_NUM >= 150000
#define RandomDouble() pg_prng_double(&pg_global_prng_state)
#else
#define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE)
#endif
#if PG_VERSION_NUM < 130000
#define list_delete_last(list) list_truncate(list, list_length(list) - 1)
#define list_sort(list, cmp) list_qsort(list, cmp)
#endif
#define HnswIsElementTuple(tup) ((tup)->type == HNSW_ELEMENT_TUPLE_TYPE)
#define HnswIsNeighborTuple(tup) ((tup)->type == HNSW_NEIGHBOR_TUPLE_TYPE)
#define HnswGetLayerM(m, layer) (layer == 0 ? m * 2 : m)
#define HnswGetMl(m) (1 / log(m))
/* Variables */
extern int hnsw_ef_search;
typedef struct HnswNeighborArray HnswNeighborArray;
typedef struct HnswElementData
{
List *heaptids;
uint8 level;
uint8 deleted;
HnswNeighborArray *neighbors;
BlockNumber blkno;
OffsetNumber offno;
OffsetNumber neighborOffno;
BlockNumber neighborPage;
Vector *vec;
} HnswElementData;
typedef HnswElementData * HnswElement;
typedef struct HnswCandidate
{
HnswElement element;
float distance;
} HnswCandidate;
typedef struct HnswNeighborArray
{
int length;
HnswCandidate *items;
} HnswNeighborArray;
typedef struct HnswUpdate
{
HnswCandidate hc;
int level;
int index;
} HnswUpdate;
typedef struct HnswPairingHeapNode
{
pairingheap_node ph_node;
HnswCandidate *inner;
} HnswPairingHeapNode;
/* HNSW index options */
typedef struct HnswOptions
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int m; /* number of connections */
int efConstruction; /* size of dynamic candidate list */
} HnswOptions;
typedef struct HnswBuildState
{
/* Info */
Relation heap;
Relation index;
IndexInfo *indexInfo;
ForkNumber forkNum;
/* Settings */
int dimensions;
int m;
int efConstruction;
/* Statistics */
double indtuples;
double reltuples;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
/* Variables */
List *elements;
HnswElement entryPoint;
double ml;
int maxLevel;
double maxInMemoryElements;
bool flushed;
Vector *normvec;
/* Memory */
MemoryContext tmpCtx;
} HnswBuildState;
typedef struct HnswMetaPageData
{
uint32 magicNumber;
uint32 version;
uint32 dimensions;
uint16 m;
uint16 efConstruction;
BlockNumber entryBlkno;
OffsetNumber entryOffno;
int16 entryLevel;
BlockNumber insertPage;
} HnswMetaPageData;
typedef HnswMetaPageData * HnswMetaPage;
typedef struct HnswPageOpaqueData
{
BlockNumber nextblkno;
uint16 unused;
uint16 page_id; /* for identification of HNSW indexes */
} HnswPageOpaqueData;
typedef HnswPageOpaqueData * HnswPageOpaque;
typedef struct HnswElementTupleData
{
uint8 type;
uint8 level;
uint8 deleted;
uint8 unused;
ItemPointerData heaptids[HNSW_HEAPTIDS];
ItemPointerData neighbortid;
uint16 unused2;
Vector vec;
} HnswElementTupleData;
typedef HnswElementTupleData * HnswElementTuple;
typedef struct HnswNeighborTupleItem
{
ItemPointerData indextid;
uint16 unused;
float distance; /* improves performance of inserts */
} HnswNeighborTupleItem;
typedef struct HnswNeighborTupleData
{
uint8 type;
uint8 unused;
uint16 count;
HnswNeighborTupleItem neighbors[FLEXIBLE_ARRAY_MEMBER];
} HnswNeighborTupleData;
typedef HnswNeighborTupleData * HnswNeighborTuple;
typedef struct HnswScanOpaqueData
{
bool first;
Buffer buf;
List *w;
MemoryContext tmpCtx;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
} HnswScanOpaqueData;
typedef HnswScanOpaqueData * HnswScanOpaque;
typedef struct HnswVacuumState
{
/* Info */
Relation index;
IndexBulkDeleteResult *stats;
IndexBulkDeleteCallback callback;
void *callback_state;
/* Settings */
int m;
int efConstruction;
/* Support functions */
FmgrInfo *procinfo;
Oid collation;
/* Variables */
HTAB *deleted;
BufferAccessStrategy bas;
HnswNeighborTuple ntup;
HnswElementData highestPoint;
/* Memory */
MemoryContext tmpCtx;
} HnswVacuumState;
/* Methods */
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
void HnswInit(void);
List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage, OffsetNumber *skipOffno);
HnswElement HnswGetEntryPoint(Relation index);
HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel);
void HnswFreeElement(HnswElement element);
HnswElement HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming);
HnswCandidate *HnswEntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadvec);
void HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum);
void HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m);
void HnswAddHeapTid(HnswElement element, ItemPointer heaptid);
void HnswInitNeighbors(HnswElement element, int m);
bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel);
void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec);
void HnswSetElementTuple(HnswElementTuple etup, HnswElement element);
/* Index access methods */
IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo);
void hnswbuildempty(Relation index);
bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 140000
,bool indexUnchanged
#endif
,IndexInfo *indexInfo
);
IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys);
void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
bool hnswgettuple(IndexScanDesc scan, ScanDirection dir);
void hnswendscan(IndexScanDesc scan);
/* Ensure fits in uint8 */
#define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - offsetof(HnswNeighborTupleData, neighbors) - sizeof(ItemIdData)) / (sizeof(HnswNeighborTupleItem)) / m) - 2, 255)
#endif

506
src/hnswbuild.c Normal file
View File

@@ -0,0 +1,506 @@
#include "postgres.h"
#include <math.h>
#include "catalog/index.h"
#include "hnsw.h"
#include "miscadmin.h"
#include "lib/pairingheap.h"
#include "nodes/pg_list.h"
#include "storage/bufmgr.h"
#include "utils/memutils.h"
#if PG_VERSION_NUM >= 140000
#include "utils/backend_progress.h"
#elif PG_VERSION_NUM >= 120000
#include "pgstat.h"
#endif
#if PG_VERSION_NUM >= 120000
#include "access/tableam.h"
#include "commands/progress.h"
#else
#define PROGRESS_CREATEIDX_TUPLES_DONE 0
#endif
#if PG_VERSION_NUM >= 130000
#define CALLBACK_ITEM_POINTER ItemPointer tid
#else
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
#if PG_VERSION_NUM >= 120000
#define UpdateProgress(index, val) pgstat_progress_update_param(index, val)
#else
#define UpdateProgress(index, val) ((void)val)
#endif
/*
* Create the metapage
*/
static void
CreateMetaPage(HnswBuildState * buildstate)
{
Relation index = buildstate->index;
ForkNumber forkNum = buildstate->forkNum;
Buffer buf;
Page page;
GenericXLogState *state;
HnswMetaPage metap;
buf = HnswNewBuffer(index, forkNum);
HnswInitRegisterPage(index, &buf, &page, &state);
/* Set metapage data */
metap = HnswPageGetMeta(page);
metap->magicNumber = HNSW_MAGIC_NUMBER;
metap->version = HNSW_VERSION;
metap->dimensions = buildstate->dimensions;
metap->m = buildstate->m;
metap->efConstruction = buildstate->efConstruction;
metap->entryBlkno = InvalidBlockNumber;
metap->entryOffno = InvalidOffsetNumber;
metap->insertPage = InvalidBlockNumber;
((PageHeader) page)->pd_lower =
((char *) metap + sizeof(HnswMetaPageData)) - (char *) page;
HnswCommitBuffer(buf, state);
}
/*
* Add a new page
*/
static void
HnswBuildAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
{
/* Add a new page */
Buffer newbuf = HnswNewBuffer(index, forkNum);
/* Update previous page */
HnswPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf);
/* Commit */
MarkBufferDirty(*buf);
GenericXLogFinish(*state);
UnlockReleaseBuffer(*buf);
/* Can take a while, so ensure we can interrupt */
/* Needs to be called when no buffer locks are held */
LockBuffer(newbuf, BUFFER_LOCK_UNLOCK);
CHECK_FOR_INTERRUPTS();
LockBuffer(newbuf, BUFFER_LOCK_EXCLUSIVE);
/* Prepare new page */
*buf = newbuf;
*state = GenericXLogStart(index);
*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(*buf, *page);
}
/*
* Create element pages
*/
static void
CreateElementPages(HnswBuildState * buildstate)
{
Relation index = buildstate->index;
ForkNumber forkNum = buildstate->forkNum;
int dimensions = buildstate->dimensions;
Size etupSize;
Size maxSize;
HnswElementTuple etup;
HnswNeighborTuple ntup;
BlockNumber insertPage;
Buffer buf;
Page page;
GenericXLogState *state;
ListCell *lc;
/* Calculate sizes */
maxSize = BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData));
etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions);
/* Allocate once */
etup = palloc0(etupSize);
ntup = palloc0(maxSize);
/* Prepare first page */
buf = HnswNewBuffer(index, forkNum);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(buf, page);
foreach(lc, buildstate->elements)
{
HnswElement element = lfirst(lc);
Size ntupSize;
Size combinedSize;
HnswSetElementTuple(etup, element);
/* Calculate sizes */
ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m);
combinedSize = etupSize + ntupSize + sizeof(ItemIdData);
/* Keep element and neighbors on the same page if possible */
if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize))
HnswBuildAppendPage(index, &buf, &page, &state, forkNum);
/* Calculate offsets */
element->blkno = BufferGetBlockNumber(buf);
element->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page));
if (combinedSize <= maxSize)
{
element->neighborPage = element->blkno;
element->neighborOffno = OffsetNumberNext(element->offno);
}
else
{
element->neighborPage = element->blkno + 1;
element->neighborOffno = FirstOffsetNumber;
}
ItemPointerSet(&etup->neighbortid, element->neighborPage, element->neighborOffno);
/* Add element */
if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != element->offno)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Add new page if needed */
if (PageGetFreeSpace(page) < ntupSize)
HnswBuildAppendPage(index, &buf, &page, &state, forkNum);
/* Add placeholder for neighbors */
if (PageAddItem(page, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != element->neighborOffno)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
}
insertPage = BufferGetBlockNumber(buf);
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
HnswUpdateMetaPage(index, true, buildstate->entryPoint, insertPage, forkNum);
pfree(etup);
pfree(ntup);
}
/*
* Create neighbor pages
*/
static void
CreateNeighborPages(HnswBuildState * buildstate)
{
Relation index = buildstate->index;
ForkNumber forkNum = buildstate->forkNum;
int m = buildstate->m;
ListCell *lc;
HnswNeighborTuple ntup;
/* Allocate once */
ntup = palloc0(BLCKSZ);
foreach(lc, buildstate->elements)
{
HnswElement e = lfirst(lc);
Buffer buf;
Page page;
GenericXLogState *state;
Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m);
/* Can take a while, so ensure we can interrupt */
/* Needs to be called when no buffer locks are held */
CHECK_FOR_INTERRUPTS();
buf = ReadBufferExtended(index, forkNum, e->neighborPage, RBM_NORMAL, NULL);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
HnswSetNeighborTuple(ntup, e, m);
if (!PageIndexTupleOverwrite(page, e->neighborOffno, (Item) ntup, ntupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
}
pfree(ntup);
}
/*
* Free elements
*/
static void
FreeElements(HnswBuildState * buildstate)
{
ListCell *lc;
foreach(lc, buildstate->elements)
HnswFreeElement(lfirst(lc));
list_free(buildstate->elements);
}
/*
* Flush pages
*/
static void
FlushPages(HnswBuildState * buildstate)
{
CreateMetaPage(buildstate);
CreateElementPages(buildstate);
CreateNeighborPages(buildstate);
buildstate->flushed = true;
FreeElements(buildstate);
}
/*
* Insert tuple
*/
static bool
InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup)
{
FmgrInfo *procinfo = buildstate->procinfo;
Oid collation = buildstate->collation;
HnswElement entryPoint = buildstate->entryPoint;
int efConstruction = buildstate->efConstruction;
int m = buildstate->m;
/* Detoast once for all calls */
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec))
return false;
}
/* Copy value to element so accessible outside of memory context */
memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions));
/* Insert element in graph */
*dup = HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, NULL, false);
/* Update entry point if needed */
if (*dup == NULL && (entryPoint == NULL || element->level > entryPoint->level))
buildstate->entryPoint = element;
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);
return *dup == NULL;
}
/*
* Callback for table_index_build_scan
*/
static void
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
HnswBuildState *buildstate = (HnswBuildState *) state;
MemoryContext oldCtx;
HnswElement element;
HnswElement dup = NULL;
bool inserted;
#if PG_VERSION_NUM < 130000
ItemPointer tid = &hup->t_self;
#endif
/* Skip nulls */
if (isnull[0])
return;
if (buildstate->indtuples >= buildstate->maxInMemoryElements)
{
if (!buildstate->flushed)
{
ereport(NOTICE,
(errmsg("hnsw graph no longer fits into maintenance_work_mem after " INT64_FORMAT " tuples", (int64) buildstate->indtuples),
errdetail("Building will take significantly more time."),
errhint("Increase maintenance_work_mem to speed up builds.")));
FlushPages(buildstate);
}
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap))
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(buildstate->tmpCtx);
return;
}
/* Allocate necessary memory outside of memory context */
element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel);
element->vec = palloc(VECTOR_SIZE(buildstate->dimensions));
/* Use memory context since detoast can allocate */
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
/* Insert tuple */
inserted = InsertTuple(index, values, element, buildstate, &dup);
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(buildstate->tmpCtx);
/* Add outside memory context */
if (dup != NULL)
HnswAddHeapTid(dup, tid);
/* Add to buildstate or free */
if (inserted)
buildstate->elements = lappend(buildstate->elements, element);
else
HnswFreeElement(element);
}
/*
* Get the max number of elements that fit into maintenance_work_mem
*/
static double
HnswGetMaxInMemoryElements(int m, double ml, int dimensions)
{
Size elementSize = sizeof(HnswElementData);
double avgLevel = -log(0.5) * ml;
elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1);
elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2));
elementSize += sizeof(ItemPointerData);
elementSize += VECTOR_SIZE(dimensions);
return (maintenance_work_mem * 1024L) / elementSize;
}
/*
* Initialize the build state
*/
static void
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
{
buildstate->heap = heap;
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->forkNum = forkNum;
buildstate->m = HnswGetM(index);
buildstate->efConstruction = HnswGetEfConstruction(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions");
if (buildstate->dimensions > HNSW_MAX_DIM)
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM);
buildstate->reltuples = 0;
buildstate->indtuples = 0;
/* Get support functions */
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
buildstate->collation = index->rd_indcollation[0];
buildstate->elements = NIL;
buildstate->entryPoint = NULL;
buildstate->ml = HnswGetMl(buildstate->m);
buildstate->maxLevel = HnswGetMaxLevel(buildstate->m);
buildstate->maxInMemoryElements = HnswGetMaxInMemoryElements(buildstate->m, buildstate->ml, buildstate->dimensions);
buildstate->flushed = false;
/* Reuse for each tuple */
buildstate->normvec = InitVector(buildstate->dimensions);
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw build temporary context",
ALLOCSET_DEFAULT_SIZES);
}
/*
* Free resources
*/
static void
FreeBuildState(HnswBuildState * buildstate)
{
pfree(buildstate->normvec);
MemoryContextDelete(buildstate->tmpCtx);
}
/*
* Build graph
*/
static void
BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum)
{
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD);
#if PG_VERSION_NUM >= 120000
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, BuildCallback, (void *) buildstate, NULL);
#else
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate, NULL);
#endif
}
/*
* Build the index
*/
static void
BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
HnswBuildState * buildstate, ForkNumber forkNum)
{
InitBuildState(buildstate, heap, index, indexInfo, forkNum);
if (buildstate->heap != NULL)
BuildGraph(buildstate, forkNum);
if (!buildstate->flushed)
FlushPages(buildstate);
FreeBuildState(buildstate);
}
/*
* Build the index for a logged table
*/
IndexBuildResult *
hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo)
{
IndexBuildResult *result;
HnswBuildState buildstate;
BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM);
result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult));
result->heap_tuples = buildstate.reltuples;
result->index_tuples = buildstate.indtuples;
return result;
}
/*
* Build the index for an unlogged table
*/
void
hnswbuildempty(Relation index)
{
IndexInfo *indexInfo = BuildIndexInfo(index);
HnswBuildState buildstate;
BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM);
}

491
src/hnswinsert.c Normal file
View File

@@ -0,0 +1,491 @@
#include "postgres.h"
#include <math.h>
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "storage/lmgr.h"
#include "utils/memutils.h"
/*
* Get the insert page
*/
static BlockNumber
GetInsertPage(Relation index)
{
Buffer buf;
Page page;
HnswMetaPage metap;
BlockNumber insertPage;
buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
metap = HnswPageGetMeta(page);
insertPage = metap->insertPage;
UnlockReleaseBuffer(buf);
return insertPage;
}
/*
* Check for a free offset
*/
static bool
HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *firstFreePage)
{
OffsetNumber offno;
OffsetNumber maxoffno = PageGetMaxOffsetNumber(page);
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
/* Skip neighbor tuples */
if (!HnswIsElementTuple(etup))
continue;
if (etup->deleted)
{
BlockNumber neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
OffsetNumber neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
ItemId itemid;
if (!BlockNumberIsValid(*firstFreePage))
*firstFreePage = neighborPage;
if (neighborPage == BufferGetBlockNumber(buf))
{
*nbuf = buf;
*npage = page;
}
else
{
*nbuf = ReadBuffer(index, neighborPage);
LockBuffer(*nbuf, BUFFER_LOCK_EXCLUSIVE);
/* Skip WAL for now */
*npage = BufferGetPage(*nbuf);
}
itemid = PageGetItemId(*npage, neighborOffno);
/* Check for space on neighbor tuple page */
if (PageGetFreeSpace(*npage) + ItemIdGetLength(itemid) - sizeof(ItemIdData) >= ntupSize)
{
*freeOffno = offno;
*freeNeighborOffno = neighborOffno;
return true;
}
else if (*nbuf != buf)
UnlockReleaseBuffer(*nbuf);
}
}
return false;
}
/*
* Add a new page
*/
static void
HnswInsertAppendPage(Relation index, Buffer *nbuf, Page *npage, GenericXLogState *state, Page page)
{
/* Add a new page */
LockRelationForExtension(index, ExclusiveLock);
*nbuf = HnswNewBuffer(index, MAIN_FORKNUM);
UnlockRelationForExtension(index, ExclusiveLock);
/* Init new page */
*npage = GenericXLogRegisterBuffer(state, *nbuf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(*nbuf, *npage);
/* Update previous buffer */
HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(*nbuf);
}
/*
* Add to element and neighbor pages
*/
static void
WriteNewElementPages(Relation index, HnswElement e, int m)
{
Buffer buf;
Page page;
GenericXLogState *state;
Size etupSize;
Size ntupSize;
Size combinedSize;
HnswElementTuple etup;
BlockNumber insertPage = GetInsertPage(index);
BlockNumber originalInsertPage = insertPage;
int dimensions = e->vec->dim;
HnswNeighborTuple ntup;
Buffer nbuf;
Page npage;
OffsetNumber freeOffno = InvalidOffsetNumber;
OffsetNumber freeNeighborOffno = InvalidOffsetNumber;
BlockNumber firstFreePage = InvalidBlockNumber;
/* Calculate sizes */
etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions);
ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m);
combinedSize = etupSize + ntupSize + sizeof(ItemIdData);
/* Prepare element tuple */
etup = palloc0(etupSize);
HnswSetElementTuple(etup, e);
/* Prepare neighbor tuple */
ntup = palloc0(ntupSize);
HnswSetNeighborTuple(ntup, e, m);
/* Find a page to insert the item */
for (;;)
{
buf = ReadBuffer(index, insertPage);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
/* Space for both */
if (PageGetFreeSpace(page) >= combinedSize)
{
nbuf = buf;
npage = page;
break;
}
/* Space for element but not neighbors and last page */
if (PageGetFreeSpace(page) >= etupSize && !BlockNumberIsValid(HnswPageGetOpaque(page)->nextblkno))
{
HnswInsertAppendPage(index, &nbuf, &npage, state, page);
break;
}
/* Space from deleted item */
if (HnswFreeOffset(index, buf, page, e, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &firstFreePage))
{
if (nbuf != buf)
npage = GenericXLogRegisterBuffer(state, nbuf, 0);
break;
}
insertPage = HnswPageGetOpaque(page)->nextblkno;
if (BlockNumberIsValid(insertPage))
{
/* Move to next page */
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
else
{
Buffer newbuf;
Page newpage;
HnswInsertAppendPage(index, &newbuf, &newpage, state, page);
/* Commit */
MarkBufferDirty(newbuf);
MarkBufferDirty(buf);
GenericXLogFinish(state);
/* Unlock previous buffer */
UnlockReleaseBuffer(buf);
/* Prepare new buffer */
state = GenericXLogStart(index);
buf = newbuf;
page = GenericXLogRegisterBuffer(state, buf, 0);
/* Create new page for neighbors if needed */
if (PageGetFreeSpace(page) < combinedSize)
HnswInsertAppendPage(index, &nbuf, &npage, state, page);
else
{
nbuf = buf;
npage = page;
}
break;
}
}
e->blkno = BufferGetBlockNumber(buf);
e->neighborPage = BufferGetBlockNumber(nbuf);
insertPage = e->neighborPage;
if (OffsetNumberIsValid(freeOffno))
{
e->offno = freeOffno;
e->neighborOffno = freeNeighborOffno;
}
else
{
e->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page));
if (nbuf == buf)
e->neighborOffno = OffsetNumberNext(e->offno);
else
e->neighborOffno = FirstOffsetNumber;
}
ItemPointerSet(&etup->neighbortid, e->neighborPage, e->neighborOffno);
/* Add element and neighbors */
if (OffsetNumberIsValid(freeOffno))
{
if (!PageIndexTupleOverwrite(page, e->offno, (Item) etup, etupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
if (!PageIndexTupleOverwrite(npage, e->neighborOffno, (Item) ntup, ntupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
}
else
{
if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != e->offno)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
if (PageAddItem(npage, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != e->neighborOffno)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
}
/* Commit */
MarkBufferDirty(buf);
if (nbuf != buf)
MarkBufferDirty(nbuf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
if (nbuf != buf)
UnlockReleaseBuffer(nbuf);
/* Update the insert page */
if (insertPage != originalInsertPage && (!OffsetNumberIsValid(freeOffno) || firstFreePage == insertPage))
HnswUpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM);
}
/*
* Calculate index for update
*/
static int
HnswGetIndex(HnswUpdate * update, int m)
{
return (update->hc.element->level - update->level) * m + update->index;
}
/*
* Update neighbors
*/
static void
UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates)
{
ListCell *lc;
/* Could update multiple at once for same element */
/* but should only happen a low percent of time, so keep simple for now */
foreach(lc, updates)
{
Buffer buf;
Page page;
GenericXLogState *state;
HnswUpdate *update = lfirst(lc);
ItemId itemid;
HnswNeighborTuple ntup;
Size ntupSize;
int idx;
OffsetNumber offno = update->hc.element->neighborOffno;
/* Register page */
buf = ReadBuffer(index, update->hc.element->neighborPage);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
/* Get tuple */
itemid = PageGetItemId(page, offno);
ntup = (HnswNeighborTuple) PageGetItem(page, itemid);
ntupSize = ItemIdGetLength(itemid);
/* Calculate index */
idx = HnswGetIndex(update, m);
/* Make robust to issues */
if (idx < ntup->count)
{
HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx];
/* Update neighbor */
ItemPointerSet(&neighbor->indextid, e->blkno, e->offno);
neighbor->distance = update->hc.distance;
/* Overwrite tuple */
if (!PageIndexTupleOverwrite(page, offno, (Item) ntup, ntupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
}
else
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
}
/*
* Add a heap TID to an existing element
*/
static bool
HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup)
{
Buffer buf;
Page page;
GenericXLogState *state;
Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim);
HnswElementTuple etup;
int i;
/* Read page */
buf = ReadBuffer(index, dup->blkno);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
/* Find space */
etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno));
for (i = 0; i < HNSW_HEAPTIDS; i++)
{
if (!ItemPointerIsValid(&etup->heaptids[i]))
break;
}
/* Either being deleted or we lost our chance to another backend */
if (i == 0 || i == HNSW_HEAPTIDS)
{
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
return false;
}
/* Add heap TID */
etup->heaptids[i] = *((ItemPointer) linitial(element->heaptids));
/* Overwrite tuple */
if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, etupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
return true;
}
/*
* Write changes to disk
*/
static void
WriteElement(Relation index, HnswElement element, int m, List *updates, HnswElement dup, HnswElement entryPoint)
{
/* Try to add to existing page */
if (dup != NULL)
{
if (HnswAddDuplicate(index, element, dup))
return;
}
/* If fails, take this path */
WriteNewElementPages(index, element, m);
UpdateNeighborPages(index, element, m, updates);
/* Update metapage if needed */
if (entryPoint == NULL || element->level > entryPoint->level)
HnswUpdateMetaPage(index, true, element, InvalidBlockNumber, MAIN_FORKNUM);
}
/*
* Insert a tuple into the index
*/
bool
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel)
{
Datum value;
FmgrInfo *normprocinfo;
HnswElement entryPoint;
HnswElement element;
int m = HnswGetM(index);
int efConstruction = HnswGetEfConstruction(index);
double ml = HnswGetMl(m);
FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
Oid collation = index->rd_indcollation[0];
List *updates = NIL;
HnswElement dup;
/* Detoast once for all calls */
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Normalize if needed */
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
if (normprocinfo != NULL)
{
if (!HnswNormValue(normprocinfo, collation, &value, NULL))
return false;
}
/* Create an element */
element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m));
element->vec = DatumGetVector(value);
/* Get entry point */
entryPoint = HnswGetEntryPoint(index);
/* Insert element in graph */
dup = HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, &updates, false);
/* Write to disk */
WriteElement(index, element, m, updates, dup, entryPoint);
return true;
}
/*
* Insert a tuple into the index
*/
bool
hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 140000
,bool indexUnchanged
#endif
,IndexInfo *indexInfo
)
{
MemoryContext oldCtx;
MemoryContext insertCtx;
/* Skip nulls */
if (isnull[0])
return false;
/* Create memory context */
insertCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw insert temporary context",
ALLOCSET_DEFAULT_SIZES);
oldCtx = MemoryContextSwitchTo(insertCtx);
/* Insert tuple */
HnswInsertTuple(index, values, isnull, heap_tid, heap);
/* Delete memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextDelete(insertCtx);
return false;
}

212
src/hnswscan.c Normal file
View File

@@ -0,0 +1,212 @@
#include "postgres.h"
#include "access/relscan.h"
#include "hnsw.h"
#include "pgstat.h"
#include "storage/bufmgr.h"
#include "utils/memutils.h"
/*
* Algorithm 5 from paper
*/
static void
GetScanItems(IndexScanDesc scan, Datum q)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
Relation index = scan->indexRelation;
FmgrInfo *procinfo = so->procinfo;
Oid collation = so->collation;
List *ep = NIL;
List *w;
HnswElement entryPoint = HnswGetEntryPoint(index);
if (entryPoint == NULL)
return;
ep = lappend(ep, HnswEntryCandidate(entryPoint, q, index, procinfo, collation, false));
for (int lc = entryPoint->level; lc >= 1; lc--)
{
w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, false, NULL, NULL);
ep = w;
}
so->w = HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL, NULL);
}
/*
* Get dimensions from metapage
*/
static int
GetDimensions(Relation index)
{
Buffer buf;
Page page;
HnswMetaPage metap;
int dimensions;
buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
metap = HnswPageGetMeta(page);
dimensions = metap->dimensions;
UnlockReleaseBuffer(buf);
return dimensions;
}
/*
* Prepare for an index scan
*/
IndexScanDesc
hnswbeginscan(Relation index, int nkeys, int norderbys)
{
IndexScanDesc scan;
HnswScanOpaque so;
scan = RelationGetIndexScan(index, nkeys, norderbys);
so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData));
so->buf = InvalidBuffer;
so->first = true;
so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw scan temporary context",
ALLOCSET_DEFAULT_SIZES);
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
so->collation = index->rd_indcollation[0];
scan->opaque = so;
return scan;
}
/*
* Start or restart an index scan
*/
void
hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
so->first = true;
MemoryContextReset(so->tmpCtx);
if (keys && scan->numberOfKeys > 0)
memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
if (orderbys && scan->numberOfOrderBys > 0)
memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
}
/*
* Fetch the next tuple in the given scan
*/
bool
hnswgettuple(IndexScanDesc scan, ScanDirection dir)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
MemoryContext oldCtx = MemoryContextSwitchTo(so->tmpCtx);
/*
* Index can be used to scan backward, but Postgres doesn't support
* backward scan on operators
*/
Assert(ScanDirectionIsForward(dir));
if (so->first)
{
Datum value;
/* Count index scan for stats */
pgstat_count_index_scan(scan->indexRelation);
/* Safety check */
if (scan->orderByData == NULL)
elog(ERROR, "cannot scan hnsw index without order");
if (scan->orderByData->sk_flags & SK_ISNULL)
value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation)));
else
{
value = scan->orderByData->sk_argument;
/* Value should not be compressed or toasted */
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
}
GetScanItems(scan, value);
so->first = false;
}
while (list_length(so->w) > 0)
{
HnswCandidate *hc = llast(so->w);
ItemPointer tid;
BlockNumber indexblkno;
/* Move to next element if no valid heap tids */
if (list_length(hc->element->heaptids) == 0)
{
so->w = list_delete_last(so->w);
continue;
}
tid = llast(hc->element->heaptids);
indexblkno = hc->element->blkno;
hc->element->heaptids = list_delete_last(hc->element->heaptids);
MemoryContextSwitchTo(oldCtx);
#if PG_VERSION_NUM >= 120000
scan->xs_heaptid = *tid;
#else
scan->xs_ctup.t_self = *tid;
#endif
if (BufferIsValid(so->buf))
ReleaseBuffer(so->buf);
/*
* An index scan must maintain a pin on the index page holding the
* item last returned by amgettuple
*
* https://www.postgresql.org/docs/current/index-locking.html
*/
so->buf = ReadBuffer(scan->indexRelation, indexblkno);
scan->xs_recheckorderby = false;
return true;
}
MemoryContextSwitchTo(oldCtx);
return false;
}
/*
* End a scan and release resources
*/
void
hnswendscan(IndexScanDesc scan)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
/* Release pin */
if (BufferIsValid(so->buf))
ReleaseBuffer(so->buf);
MemoryContextDelete(so->tmpCtx);
pfree(so);
scan->opaque = NULL;
}

982
src/hnswutils.c Normal file
View File

@@ -0,0 +1,982 @@
#include "postgres.h"
#include <math.h>
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "vector.h"
/*
* Get the number of connection in the index
*/
int
HnswGetM(Relation index)
{
HnswOptions *opts = (HnswOptions *) index->rd_options;
if (opts)
return opts->m;
return HNSW_DEFAULT_M;
}
/*
* Get the size of the dynamic candidate list in the index
*/
int
HnswGetEfConstruction(Relation index)
{
HnswOptions *opts = (HnswOptions *) index->rd_options;
if (opts)
return opts->efConstruction;
return HNSW_DEFAULT_EF_CONSTRUCTION;
}
/*
* Get proc
*/
FmgrInfo *
HnswOptionalProcInfo(Relation rel, uint16 procnum)
{
if (!OidIsValid(index_getprocid(rel, 1, procnum)))
return NULL;
return index_getprocinfo(rel, 1, procnum);
}
/*
* Divide by the norm
*
* Returns false if value should not be indexed
*
* The caller needs to free the pointer stored in value
* if it's different than the original value
*/
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
Vector *v = DatumGetVector(*value);
if (result == NULL)
result = InitVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
return true;
}
return false;
}
/*
* New buffer
*/
Buffer
HnswNewBuffer(Relation index, ForkNumber forkNum)
{
Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
return buf;
}
/*
* Init page
*/
void
HnswInitPage(Buffer buf, Page page)
{
PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData));
HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber;
HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID;
}
/*
* Init and register page
*/
void
HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
*state = GenericXLogStart(index);
*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(*buf, *page);
}
/*
* Commit buffer
*/
void
HnswCommitBuffer(Buffer buf, GenericXLogState *state)
{
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
}
/*
* Allocate neighbors
*/
void
HnswInitNeighbors(HnswElement element, int m)
{
int level = element->level;
element->neighbors = palloc(sizeof(HnswNeighborArray) * (level + 1));
for (int lc = 0; lc <= level; lc++)
{
HnswNeighborArray *a;
int lm = HnswGetLayerM(m, lc);
a = &element->neighbors[lc];
a->length = 0;
a->items = palloc(sizeof(HnswCandidate) * lm);
}
}
/*
* Allocate an element
*/
HnswElement
HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel)
{
HnswElement element = palloc(sizeof(HnswElementData));
int level = (int) (-log(RandomDouble()) * ml);
/* Cap level */
if (level > maxLevel)
level = maxLevel;
element->heaptids = NIL;
HnswAddHeapTid(element, heaptid);
element->level = level;
element->deleted = 0;
HnswInitNeighbors(element, m);
return element;
}
/*
* Free an element
*/
void
HnswFreeElement(HnswElement element)
{
list_free_deep(element->heaptids);
for (int lc = 0; lc <= element->level; lc++)
pfree(element->neighbors[lc].items);
pfree(element->neighbors);
pfree(element->vec);
pfree(element);
}
/*
* Add a heap TID to an element
*/
void
HnswAddHeapTid(HnswElement element, ItemPointer heaptid)
{
ItemPointer copy = palloc(sizeof(ItemPointerData));
ItemPointerCopy(heaptid, copy);
element->heaptids = lappend(element->heaptids, copy);
}
/*
* Allocate an element from block and offset numbers
*/
static HnswElement
InitElementFromBlock(BlockNumber blkno, OffsetNumber offno)
{
HnswElement element = palloc(sizeof(HnswElementData));
element->blkno = blkno;
element->offno = offno;
element->neighbors = NULL;
element->vec = NULL;
return element;
}
/*
* Get the entry point
*/
HnswElement
HnswGetEntryPoint(Relation index)
{
Buffer buf;
Page page;
HnswMetaPage metap;
HnswElement entryPoint = NULL;
buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
metap = HnswPageGetMeta(page);
if (BlockNumberIsValid(metap->entryBlkno))
entryPoint = InitElementFromBlock(metap->entryBlkno, metap->entryOffno);
UnlockReleaseBuffer(buf);
return entryPoint;
}
/*
* Update the metapage
*/
void
HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum)
{
Buffer buf;
Page page;
GenericXLogState *state;
HnswMetaPage metap;
buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
metap = HnswPageGetMeta(page);
if (updateEntry)
{
if (entryPoint == NULL)
{
metap->entryBlkno = InvalidBlockNumber;
metap->entryOffno = InvalidOffsetNumber;
metap->entryLevel = -1;
}
else
{
metap->entryBlkno = entryPoint->blkno;
metap->entryOffno = entryPoint->offno;
metap->entryLevel = entryPoint->level;
}
}
if (BlockNumberIsValid(insertPage))
metap->insertPage = insertPage;
HnswCommitBuffer(buf, state);
}
/*
* Set element tuple, except for neighbor info
*/
void
HnswSetElementTuple(HnswElementTuple etup, HnswElement element)
{
etup->type = HNSW_ELEMENT_TUPLE_TYPE;
etup->level = element->level;
etup->deleted = 0;
for (int i = 0; i < HNSW_HEAPTIDS; i++)
{
if (i < list_length(element->heaptids))
etup->heaptids[i] = *((ItemPointer) list_nth(element->heaptids, i));
else
ItemPointerSetInvalid(&etup->heaptids[i]);
}
memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim));
}
/*
* Set neighbor tuple
*/
void
HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m)
{
int idx = 0;
ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE;
for (int lc = e->level; lc >= 0; lc--)
{
HnswNeighborArray *neighbors = &e->neighbors[lc];
int lm = HnswGetLayerM(m, lc);
for (int i = 0; i < lm; i++)
{
HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx++];
if (i < neighbors->length)
{
HnswCandidate *hc = &neighbors->items[i];
ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno);
neighbor->distance = hc->distance;
}
else
{
ItemPointerSetInvalid(&neighbor->indextid);
neighbor->distance = NAN;
}
}
}
ntup->count = idx;
}
/*
* Load neighbors from page
*/
static void
LoadNeighborsFromPage(HnswElement element, Relation index, Page page)
{
HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));
int m = HnswGetM(index);
int neighborCount = (element->level + 2) * m;
Assert(HnswIsNeighborTuple(ntup));
HnswInitNeighbors(element, m);
/* Ensure expected neighbors */
if (ntup->count != neighborCount)
return;
for (int i = 0; i < neighborCount; i++)
{
HnswElement e;
int level;
HnswCandidate *hc;
HnswNeighborTupleItem *neighbor;
HnswNeighborArray *neighbors;
neighbor = &ntup->neighbors[i];
if (!ItemPointerIsValid(&neighbor->indextid))
continue;
e = InitElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid));
/* Calculate level based on offset */
level = element->level - i / m;
if (level < 0)
level = 0;
neighbors = &element->neighbors[level];
hc = &neighbors->items[neighbors->length++];
hc->element = e;
hc->distance = neighbor->distance;
}
}
/*
* Load neighbors
*/
static void
LoadNeighbors(HnswElement element, Relation index)
{
Buffer buf;
Page page;
buf = ReadBuffer(index, element->neighborPage);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
LoadNeighborsFromPage(element, index, page);
UnlockReleaseBuffer(buf);
}
/*
* Load an element and optionally get its distance from q
*/
void
HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec)
{
Buffer buf;
Page page;
HnswElementTuple etup;
/* Read vector */
buf = ReadBuffer(index, element->blkno);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno));
Assert(HnswIsElementTuple(etup));
/* Load element */
element->heaptids = NIL;
for (int i = 0; i < HNSW_HEAPTIDS; i++)
{
/* Can stop at first invalid */
if (!ItemPointerIsValid(&etup->heaptids[i]))
break;
HnswAddHeapTid(element, &etup->heaptids[i]);
}
element->level = etup->level;
element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
element->deleted = etup->deleted;
if (loadVec)
{
element->vec = palloc(VECTOR_SIZE(etup->vec.dim));
memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim));
}
/* Calculate distance */
if (distance != NULL)
*distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec)));
UnlockReleaseBuffer(buf);
}
/*
* Get the distance for a candidate
*/
static float
GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation)
{
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec)));
}
/*
* Create a candidate for the entry point
*/
HnswCandidate *
HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec)
{
HnswCandidate *hc = palloc(sizeof(HnswCandidate));
hc->element = entryPoint;
if (index == NULL)
hc->distance = GetCandidateDistance(hc, q, procinfo, collation);
else
HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadvec);
return hc;
}
/*
* Compare candidate distances
*/
static int
CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
return 1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
return -1;
return 0;
}
/*
* Compare candidate distances
*/
static int
CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
return -1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
return 1;
return 0;
}
/*
* Create a pairing heap node for a candidate
*/
static HnswPairingHeapNode *
CreatePairingHeapNode(HnswCandidate * c)
{
HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode));
node->inner = c;
return node;
}
/*
* Add to visited
*/
static inline void
AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found)
{
if (index == NULL)
hash_search(v, &hc->element, HASH_ENTER, found);
else
{
ItemPointerData indextid;
ItemPointerSet(&indextid, hc->element->blkno, hc->element->offno);
hash_search(v, &indextid, HASH_ENTER, found);
}
}
/*
* Algorithm 2 from paper
*/
List *
HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage, OffsetNumber *skipOffno)
{
ListCell *lc2;
List *w = NIL;
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
int wlen = 0;
HASHCTL hash_ctl;
HTAB *v;
/* Create hash table */
if (index == NULL)
{
hash_ctl.keysize = sizeof(HnswElement *);
hash_ctl.entrysize = sizeof(HnswElement *);
}
else
{
hash_ctl.keysize = sizeof(ItemPointerData);
hash_ctl.entrysize = sizeof(ItemPointerData);
}
hash_ctl.hcxt = CurrentMemoryContext;
v = hash_create("hnsw visited", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
/* Add entry points to v, C, and W */
foreach(lc2, ep)
{
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
AddToVisited(v, hc, index, NULL);
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
wlen++;
}
while (!pairingheap_is_empty(C))
{
HnswNeighborArray *neighborhood;
HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner;
HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
if (c->distance > f->distance)
break;
if (c->element->neighbors == NULL)
LoadNeighbors(c->element, index);
/* Get the neighborhood at layer lc */
neighborhood = &c->element->neighbors[lc];
for (int i = 0; i < neighborhood->length; i++)
{
HnswCandidate *e = &neighborhood->items[i];
bool visited;
AddToVisited(v, e, index, &visited);
if (!visited)
{
float eDistance;
f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
if (index == NULL)
eDistance = GetCandidateDistance(e, q, procinfo, collation);
else
HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting);
/* Skip if fully deleted */
if (e->element->deleted)
continue;
/* Skip for inserts if deleting */
if (inserting && list_length(e->element->heaptids) == 0)
continue;
/* Skip self for vacuuming update */
if (skipPage != NULL && e->element->neighborPage == *skipPage && e->element->neighborOffno == *skipOffno)
continue;
/* Make robust to issues */
if (e->element->level < lc)
continue;
if (eDistance < f->distance || wlen < ef)
{
/* Copy e */
HnswCandidate *ec = palloc(sizeof(HnswCandidate));
ec->element = e->element;
ec->distance = eDistance;
pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node));
wlen++;
/* No need to decrement wlen */
if (wlen > ef)
pairingheap_remove_first(W);
}
}
}
}
/* Add each element of W to w */
while (!pairingheap_is_empty(W))
{
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
w = lappend(w, hc);
}
return w;
}
/*
* Calculate the distance between elements
*/
static float
HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation)
{
/* Look for cached distance */
if (a->neighbors != NULL)
{
Assert(a->level >= lc);
for (int i = 0; i < a->neighbors[lc].length; i++)
{
if (a->neighbors[lc].items[i].element == b)
return a->neighbors[lc].items[i].distance;
}
}
if (b->neighbors != NULL)
{
Assert(b->level >= lc);
for (int i = 0; i < b->neighbors[lc].length; i++)
{
if (b->neighbors[lc].items[i].element == a)
return b->neighbors[lc].items[i].distance;
}
}
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec)));
}
/*
* Check if an element is closer to q than any element from R
*/
static bool
CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation)
{
ListCell *lc2;
foreach(lc2, r)
{
HnswCandidate *ri = lfirst(lc2);
float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation);
if (distance <= e->distance)
return false;
}
return true;
}
/*
* Algorithm 4 from paper
*/
static List *
SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned)
{
List *r = NIL;
List *w = list_copy(c);
pairingheap *wd;
if (list_length(w) < m)
return w;
wd = pairingheap_allocate(CompareNearestCandidates, NULL);
while (list_length(w) > 0 && list_length(r) < m)
{
/* Assumes w is already ordered desc */
HnswCandidate *e = llast(w);
bool closer;
w = list_delete_last(w);
closer = CheckElementCloser(e, r, lc, procinfo, collation);
if (closer)
r = lappend(r, e);
else
pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node));
}
/* Keep pruned connections */
while (!pairingheap_is_empty(wd) && list_length(r) < m)
r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner);
/* Return pruned for update connections */
if (pruned != NULL)
{
if (!pairingheap_is_empty(wd))
*pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner;
else
*pruned = linitial(w);
}
return r;
}
/*
* Find duplicate element
*/
static HnswElement
HnswFindDuplicate(HnswElement e, List *neighbors)
{
ListCell *lc;
foreach(lc, neighbors)
{
HnswCandidate *neighbor = lfirst(lc);
/* Exit early since ordered by distance */
if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0)
break;
/* Check for space */
if (list_length(neighbor->element->heaptids) < HNSW_HEAPTIDS)
return neighbor->element;
}
return NULL;
}
/*
* Add connections
*/
static void
AddConnections(HnswElement element, List *neighbors, int m, int lc)
{
ListCell *lc2;
HnswNeighborArray *a = &element->neighbors[lc];
foreach(lc2, neighbors)
a->items[a->length++] = *((HnswCandidate *) lfirst(lc2));
}
/*
* Compare candidate distances
*/
static int
#if PG_VERSION_NUM >= 130000
CompareCandidateDistances(const ListCell *a, const ListCell *b)
#else
CompareCandidateDistances(const void *a, const void *b)
#endif
{
HnswCandidate *hca = lfirst((ListCell *) a);
HnswCandidate *hcb = lfirst((ListCell *) b);
if (hca->distance < hcb->distance)
return 1;
if (hca->distance > hcb->distance)
return -1;
return 0;
}
/*
* Create update
*/
static HnswUpdate *
CreateUpdate(HnswCandidate * hc, int level, int index)
{
HnswUpdate *update = palloc(sizeof(HnswUpdate));
update->hc = *hc;
update->level = level;
update->index = index;
return update;
}
/*
* Update connections
*/
static void
UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation)
{
ListCell *lc2;
foreach(lc2, neighbors)
{
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc];
HnswCandidate hc2;
hc2.element = element;
hc2.distance = hc->distance;
if (currentNeighbors->length < m)
{
currentNeighbors->items[currentNeighbors->length++] = hc2;
/* Track updates */
if (updates != NULL)
*updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1));
}
else
{
/* Shrink connections */
HnswCandidate *pruned = NULL;
List *c = NIL;
/* Add and sort candidates */
for (int i = 0; i < currentNeighbors->length; i++)
c = lappend(c, &currentNeighbors->items[i]);
c = lappend(c, &hc2);
list_sort(c, CompareCandidateDistances);
/* Load elements on insert */
if (index != NULL)
{
for (int i = 0; i < currentNeighbors->length; i++)
{
if (currentNeighbors->items[i].element->vec == NULL)
{
HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true);
/* Prune deleted element */
if (currentNeighbors->items[i].element->deleted)
{
pruned = &currentNeighbors->items[i];
break;
}
}
}
}
if (pruned == NULL)
{
SelectNeighbors(c, m, lc, procinfo, collation, &pruned);
/* Should not happen */
if (pruned == NULL)
continue;
}
/* Find and replace the pruned element */
for (int i = 0; i < currentNeighbors->length; i++)
{
if (currentNeighbors->items[i].element == pruned->element)
{
currentNeighbors->items[i] = hc2;
/* Track updates */
if (updates != NULL)
*updates = lappend(*updates, CreateUpdate(hc, lc, i));
break;
}
}
}
}
}
/*
* Algorithm 1 from paper
*/
HnswElement
HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming)
{
List *ep = NIL;
List *w;
int level = element->level;
int entryLevel;
List **newNeighbors = palloc(sizeof(List *) * (level + 1));
Datum q = PointerGetDatum(element->vec);
HnswElement dup;
BlockNumber *skipPage = vacuuming ? &element->neighborPage : NULL;
OffsetNumber *skipOffno = vacuuming ? &element->neighborOffno : NULL;
bool removeEntryPoint;
HnswCandidate *entryCandidate;
/* Get entry point and level */
if (entryPoint != NULL)
{
entryCandidate = HnswEntryCandidate(entryPoint, q, index, procinfo, collation, true);
ep = lappend(ep, entryCandidate);
entryLevel = entryPoint->level;
removeEntryPoint = vacuuming && list_length(entryPoint->heaptids) == 0;
}
else
{
entryLevel = -1;
removeEntryPoint = false;
}
/* 1st phase: greedy search to insert level */
for (int lc = entryLevel; lc >= level + 1; lc--)
{
w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, true, skipPage, skipOffno);
ep = w;
}
if (level > entryLevel)
level = entryLevel;
/* 2nd phase */
for (int lc = level; lc >= 0; lc--)
{
int lm = HnswGetLayerM(m, lc);
w = HnswSearchLayer(q, ep, efConstruction, lc, index, procinfo, collation, true, skipPage, skipOffno);
/* Remove entry point if it's being deleted */
if (removeEntryPoint)
w = list_delete_ptr(w, entryCandidate);
newNeighbors[lc] = SelectNeighbors(w, lm, lc, procinfo, collation, NULL);
ep = w;
}
/* Look for duplicate */
if (level >= 0 && !vacuuming)
{
dup = HnswFindDuplicate(element, newNeighbors[0]);
if (dup != NULL)
return dup;
}
/* Update connections */
for (int lc = level; lc >= 0; lc--)
{
int lm = HnswGetLayerM(m, lc);
AddConnections(element, newNeighbors[lc], lm, lc);
if (!vacuuming)
UpdateConnections(element, newNeighbors[lc], lm, lc, updates, index, procinfo, collation);
}
return NULL;
}

584
src/hnswvacuum.c Normal file
View File

@@ -0,0 +1,584 @@
#include "postgres.h"
#include <math.h>
#include "commands/vacuum.h"
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "utils/memutils.h"
/*
* Check if deleted list contains an index tid
*/
static bool
DeletedContains(HTAB *deleted, ItemPointer indextid)
{
bool found;
hash_search(deleted, indextid, HASH_FIND, &found);
return found;
}
/*
* Remove deleted heap TIDs
*
* OK to remove for entry point, since always considered for searches and inserts
*/
static void
RemoveHeapTids(HnswVacuumState * vacuumstate)
{
BlockNumber blkno = HNSW_HEAD_BLKNO;
HnswElement highestPoint = &vacuumstate->highestPoint;
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
HnswElement entryPoint = HnswGetEntryPoint(vacuumstate->index);
/* Store separately since highestPoint.level is uint8 */
int highestLevel = -1;
/* Initialize highest point */
highestPoint->blkno = InvalidBlockNumber;
highestPoint->offno = InvalidOffsetNumber;
while (BlockNumberIsValid(blkno))
{
Buffer buf;
Page page;
GenericXLogState *state;
OffsetNumber offno;
OffsetNumber maxoffno;
bool updated = false;
vacuum_delay_point();
buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
maxoffno = PageGetMaxOffsetNumber(page);
/* Iterate over nodes */
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
int idx = 0;
bool itemUpdated = false;
/* Skip neighbor tuples */
if (!HnswIsElementTuple(etup))
continue;
if (ItemPointerIsValid(&etup->heaptids[0]))
{
for (int i = 0; i < HNSW_HEAPTIDS; i++)
{
/* Stop at first unused */
if (!ItemPointerIsValid(&etup->heaptids[i]))
break;
if (vacuumstate->callback(&etup->heaptids[i], vacuumstate->callback_state))
itemUpdated = true;
else
{
/* Move to front of list */
etup->heaptids[idx++] = etup->heaptids[i];
}
}
if (itemUpdated)
{
Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim);
/* Mark rest as invalid */
for (int i = idx; i < HNSW_HEAPTIDS; i++)
ItemPointerSetInvalid(&etup->heaptids[i]);
if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
updated = true;
}
}
if (!ItemPointerIsValid(&etup->heaptids[0]))
{
ItemPointerData ip;
/* Add to deleted list */
ItemPointerSet(&ip, blkno, offno);
(void) hash_search(vacuumstate->deleted, &ip, HASH_ENTER, NULL);
}
else if (etup->level > highestLevel && !(blkno == entryPoint->blkno && offno == entryPoint->offno))
{
/* Keep track of highest non-entry point */
highestPoint->blkno = blkno;
highestPoint->offno = offno;
highestPoint->level = etup->level;
highestLevel = etup->level;
}
}
blkno = HnswPageGetOpaque(page)->nextblkno;
if (updated)
{
MarkBufferDirty(buf);
GenericXLogFinish(state);
}
else
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
}
/*
* Check for deleted neighbors
*/
static bool
NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element)
{
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
Buffer buf;
Page page;
HnswNeighborTuple ntup;
bool needsUpdated = false;
buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));
Assert(HnswIsNeighborTuple(ntup));
/* Check neighbors */
for (int i = 0; i < ntup->count; i++)
{
HnswNeighborTupleItem *neighbor = &ntup->neighbors[i];
if (!ItemPointerIsValid(&neighbor->indextid))
continue;
/* Check if in deleted list */
if (DeletedContains(vacuumstate->deleted, &neighbor->indextid))
{
needsUpdated = true;
break;
}
}
UnlockReleaseBuffer(buf);
return needsUpdated;
}
/*
* Repair graph for a single element
*/
static void
RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element)
{
Relation index = vacuumstate->index;
Buffer buf;
Page page;
GenericXLogState *state;
int m = vacuumstate->m;
int efConstruction = vacuumstate->efConstruction;
FmgrInfo *procinfo = vacuumstate->procinfo;
Oid collation = vacuumstate->collation;
HnswElement entryPoint;
BufferAccessStrategy bas = vacuumstate->bas;
HnswNeighborTuple ntup = vacuumstate->ntup;
Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m);
/* Check if any neighbors point to deleted values */
if (!NeedsUpdated(vacuumstate, element))
return;
/* Refresh entry point for each element */
entryPoint = HnswGetEntryPoint(index);
/* Special case for entry point */
if (element->blkno == entryPoint->blkno && element->offno == entryPoint->offno)
{
if (BlockNumberIsValid(vacuumstate->highestPoint.blkno))
{
/* Already updated */
if (vacuumstate->highestPoint.blkno == element->blkno && vacuumstate->highestPoint.offno == element->offno)
return;
entryPoint = &vacuumstate->highestPoint;
/* Reset neighbors from previous update */
entryPoint->neighbors = NULL;
}
else
entryPoint = NULL;
}
/* Init fields */
HnswInitNeighbors(element, m);
element->heaptids = NIL;
/* Add element to graph, skipping itself */
HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, NULL, true);
/* Update neighbor tuple */
/* Do this before getting page to minimize locking */
HnswSetNeighborTuple(ntup, element, m);
/* Get neighbor page */
buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
/* Overwrite tuple */
if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
}
/*
* Repair graph entry point
*/
static void
RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
{
Relation index = vacuumstate->index;
HnswElement highestPoint = &vacuumstate->highestPoint;
HnswElement entryPoint;
MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx);
/* Repair graph for highest non-entry point */
/* This may not be the highest with new inserts, but should be fine */
if (BlockNumberIsValid(highestPoint->blkno))
{
HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true);
RepairGraphElement(vacuumstate, highestPoint);
}
/* See if entry point needs updated */
entryPoint = HnswGetEntryPoint(index);
if (entryPoint != NULL)
{
ItemPointerData epData;
ItemPointerSet(&epData, entryPoint->blkno, entryPoint->offno);
if (DeletedContains(vacuumstate->deleted, &epData))
HnswUpdateMetaPage(index, true, highestPoint, InvalidBlockNumber, MAIN_FORKNUM);
else
{
/* Highest point will be used to repair */
HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true);
RepairGraphElement(vacuumstate, entryPoint);
}
}
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(vacuumstate->tmpCtx);
}
/*
* Repair graph for all elements
*/
static void
RepairGraph(HnswVacuumState * vacuumstate)
{
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
BlockNumber blkno = HNSW_HEAD_BLKNO;
RepairGraphEntryPoint(vacuumstate);
while (BlockNumberIsValid(blkno))
{
Buffer buf;
Page page;
OffsetNumber offno;
OffsetNumber maxoffno;
List *elements = NIL;
ListCell *lc2;
MemoryContext oldCtx;
vacuum_delay_point();
oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx);
buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
maxoffno = PageGetMaxOffsetNumber(page);
/* Load items into memory to minimize locking */
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
HnswElement element;
/* Skip neighbor tuples */
if (!HnswIsElementTuple(etup))
continue;
/* Skip updating neighbors if being deleted */
if (!ItemPointerIsValid(&etup->heaptids[0]))
continue;
/* Create an element */
element = palloc(sizeof(HnswElementData));
element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
element->level = etup->level;
element->blkno = blkno;
element->offno = offno;
element->vec = palloc(VECTOR_SIZE(etup->vec.dim));
memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim));
elements = lappend(elements, element);
}
blkno = HnswPageGetOpaque(page)->nextblkno;
UnlockReleaseBuffer(buf);
/* Update neighbor pages */
foreach(lc2, elements)
RepairGraphElement(vacuumstate, (HnswElement) lfirst(lc2));
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(vacuumstate->tmpCtx);
}
}
/*
* Mark items as deleted
*/
static void
MarkDeleted(HnswVacuumState * vacuumstate)
{
BlockNumber blkno = HNSW_HEAD_BLKNO;
BlockNumber insertPage = InvalidBlockNumber;
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
IndexBulkDeleteResult *stats = vacuumstate->stats;
while (BlockNumberIsValid(blkno))
{
Buffer buf;
Page page;
GenericXLogState *state;
OffsetNumber offno;
OffsetNumber maxoffno;
vacuum_delay_point();
buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas);
/*
* ambulkdelete cannot delete entries from pages that are pinned by
* other backends
*
* https://www.postgresql.org/docs/current/index-locking.html
*/
LockBufferForCleanup(buf);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
maxoffno = PageGetMaxOffsetNumber(page);
/* Update element and neighbors together */
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
HnswNeighborTuple ntup;
Size etupSize;
Size ntupSize;
Buffer nbuf;
Page npage;
BlockNumber neighborPage;
OffsetNumber neighborOffno;
/* Skip neighbor tuples */
if (!HnswIsElementTuple(etup))
continue;
/* Skip deleted tuples */
if (etup->deleted)
continue;
/* Skip live tuples */
if (ItemPointerIsValid(&etup->heaptids[0]))
{
stats->num_index_tuples++;
continue;
}
/* Update stats */
stats->tuples_removed++;
/* Calculate sizes */
etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim);
ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m);
/* Get neighbor page */
neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
if (neighborPage == blkno)
{
nbuf = buf;
npage = page;
}
else
{
nbuf = ReadBufferExtended(index, MAIN_FORKNUM, neighborPage, RBM_NORMAL, bas);
LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE);
npage = GenericXLogRegisterBuffer(state, nbuf, 0);
}
ntup = (HnswNeighborTuple) PageGetItem(npage, PageGetItemId(npage, neighborOffno));
/* Overwrite element */
etup->deleted = 1;
MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float));
/* Overwrite neighbors */
for (int i = 0; i < ntup->count; i++)
{
ItemPointerSetInvalid(&ntup->neighbors[i].indextid);
ntup->neighbors[i].distance = NAN;
}
/* Overwrite element tuple */
if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Overwrite neighbor tuple */
if (!PageIndexTupleOverwrite(npage, neighborOffno, (Item) ntup, ntupSize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Commit */
MarkBufferDirty(buf);
if (nbuf != buf)
MarkBufferDirty(nbuf);
GenericXLogFinish(state);
if (nbuf != buf)
UnlockReleaseBuffer(nbuf);
/* Set to first free page */
if (!BlockNumberIsValid(insertPage))
insertPage = blkno;
/* Prepare new xlog */
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
}
blkno = HnswPageGetOpaque(page)->nextblkno;
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
HnswUpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM);
}
/*
* Initialize the vacuum state
*/
static void
InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state)
{
Relation index = info->index;
HASHCTL hash_ctl;
if (stats == NULL)
stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult));
vacuumstate->index = index;
vacuumstate->stats = stats;
vacuumstate->callback = callback;
vacuumstate->callback_state = callback_state;
vacuumstate->m = HnswGetM(index);
vacuumstate->efConstruction = HnswGetEfConstruction(index);
vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD);
vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
vacuumstate->collation = index->rd_indcollation[0];
vacuumstate->ntup = palloc0(BLCKSZ);
vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw vacuum temporary context",
ALLOCSET_DEFAULT_SIZES);
/* Create hash table */
hash_ctl.keysize = sizeof(ItemPointerData);
hash_ctl.entrysize = sizeof(ItemPointerData);
hash_ctl.hcxt = CurrentMemoryContext;
vacuumstate->deleted = hash_create("hnswbulkdelete indextids", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
}
/*
* Free resources
*/
static void
FreeVacuumState(HnswVacuumState * vacuumstate)
{
hash_destroy(vacuumstate->deleted);
FreeAccessStrategy(vacuumstate->bas);
pfree(vacuumstate->ntup);
MemoryContextDelete(vacuumstate->tmpCtx);
}
/*
* Bulk delete tuples from the index
*/
IndexBulkDeleteResult *
hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
IndexBulkDeleteCallback callback, void *callback_state)
{
HnswVacuumState vacuumstate;
InitVacuumState(&vacuumstate, info, stats, callback, callback_state);
/* Pass 1: Remove heap TIDs */
RemoveHeapTids(&vacuumstate);
/* Pass 2: Repair graph */
RepairGraph(&vacuumstate);
/* Pass 3: Mark as deleted */
MarkDeleted(&vacuumstate);
FreeVacuumState(&vacuumstate);
return vacuumstate.stats;
}
/*
* Clean up after a VACUUM operation
*/
IndexBulkDeleteResult *
hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats)
{
Relation rel = info->index;
if (info->analyze_only)
return stats;
/* stats is NULL if ambulkdelete not called */
/* OK to return NULL if index not changed */
if (stats == NULL)
return NULL;
stats->num_pages = RelationGetNumberOfBlocks(rel);
return stats;
}

View File

@@ -3,10 +3,6 @@
#include "postgres.h"
#if PG_VERSION_NUM < 110000
#error "Requires PostgreSQL 11+"
#endif
#include "access/generic_xlog.h"
#include "access/parallel.h"
#include "access/reloptions.h"

View File

@@ -4,6 +4,7 @@
#include "catalog/pg_type.h"
#include "fmgr.h"
#include "hnsw.h"
#include "ivfflat.h"
#include "lib/stringinfo.h"
#include "libpq/pqformat.h"
@@ -37,6 +38,7 @@ PG_MODULE_MAGIC;
void
_PG_init(void)
{
HnswInit();
IvfflatInit();
}

View File

@@ -0,0 +1,26 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
val
---------
[1,1,1]
[1,2,3]
[1,2,4]
(3 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
count
-------
3
(1 row)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
count
-------
3
(1 row)
DROP TABLE t;

21
test/expected/hnsw_ip.out Normal file
View File

@@ -0,0 +1,21 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_ip_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
val
---------
[1,2,4]
[1,2,3]
[1,1,1]
[0,0,0]
(4 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2;
count
-------
4
(1 row)
DROP TABLE t;

30
test/expected/hnsw_l2.out Normal file
View File

@@ -0,0 +1,30 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,2,4]
[1,1,1]
[0,0,0]
(4 rows)
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
val
---------
[0,0,0]
[1,1,1]
[1,2,3]
[1,2,4]
(4 rows)
SELECT COUNT(*) FROM t;
count
-------
5
(1 row)
DROP TABLE t;

View File

@@ -0,0 +1,25 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3);
ERROR: value 3 out of bounds for option "m"
DETAIL: Valid values are between "4" and "100".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101);
ERROR: value 101 out of bounds for option "m"
DETAIL: Valid values are between "4" and "100".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9);
ERROR: value 9 out of bounds for option "ef_construction"
DETAIL: Valid values are between "10" and "1000".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001);
ERROR: value 1001 out of bounds for option "ef_construction"
DETAIL: Valid values are between "10" and "1000".
SHOW hnsw.ef_search;
hnsw.ef_search
----------------
40
(1 row)
SET hnsw.ef_search = 9;
ERROR: 9 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000)
SET hnsw.ef_search = 1001;
ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000)
DROP TABLE t;

View File

@@ -0,0 +1,13 @@
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,1,1]
[0,0,0]
(3 rows)
DROP TABLE t;

13
test/sql/hnsw_cosine.sql Normal file
View File

@@ -0,0 +1,13 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
DROP TABLE t;

12
test/sql/hnsw_ip.sql Normal file
View File

@@ -0,0 +1,12 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_ip_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2;
DROP TABLE t;

13
test/sql/hnsw_l2.sql Normal file
View File

@@ -0,0 +1,13 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
SELECT COUNT(*) FROM t;
DROP TABLE t;

14
test/sql/hnsw_options.sql Normal file
View File

@@ -0,0 +1,14 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001);
SHOW hnsw.ef_search;
SET hnsw.ef_search = 9;
SET hnsw.ef_search = 1001;
DROP TABLE t;

View File

@@ -0,0 +1,9 @@
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
DROP TABLE t;

99
test/t/010_hnsw_wal.pl Normal file
View File

@@ -0,0 +1,99 @@
# Based on postgres/contrib/bloom/t/001_wal.pl
# Test generic xlog record work for hnsw index replication.
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $dim = 32;
my $node_primary;
my $node_replica;
# Run few queries on both primary and replica and check their results match.
sub test_index_replay
{
my ($test_name) = @_;
# Wait for replica to catch up
my $applname = $node_replica->name;
my $server_version_num = $node_primary->safe_psql("postgres", "SHOW server_version_num");
my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';";
$node_primary->poll_query_until('postgres', $caughtup_query)
or die "Timed out while waiting for replica 1 to catch up";
my @r = ();
for (1 .. $dim) {
push(@r, rand());
}
my $sql = join(",", @r);
my $queries = qq(
SET enable_seqscan = off;
SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10;
);
# Run test queries and compare their result
my $primary_result = $node_primary->safe_psql("postgres", $queries);
my $replica_result = $node_replica->safe_psql("postgres", $queries);
is($primary_result, $replica_result, "$test_name: query result matches");
return;
}
# Use ARRAY[random(), random(), random(), ...] over
# SELECT array_agg(random()) FROM generate_series(1, $dim)
# to generate different values for each row
my $array_sql = join(",", ('random()') x $dim);
# Initialize primary node
$node_primary = get_new_node('primary');
$node_primary->init(allows_streaming => 1);
if ($dim > 32) {
# TODO use wal_keep_segments for Postgres < 13
$node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB));
}
if ($dim > 1500) {
$node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB));
}
$node_primary->start;
my $backup_name = 'my_backup';
# Take backup
$node_primary->backup($backup_name);
# Create streaming replica linking to primary
$node_replica = get_new_node('replica');
$node_replica->init_from_backup($node_primary, $backup_name,
has_streaming => 1);
$node_replica->start;
# Create hnsw index on primary
$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;");
$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 1000) i;"
);
$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
# Test that queries give same result
test_index_replay('initial');
# Run 10 cycles of table modification. Run test queries after each modification.
for my $i (1 .. 10)
{
$node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;");
test_index_replay("delete $i");
$node_primary->safe_psql("postgres", "VACUUM tst;");
test_index_replay("vacuum $i");
my ($start, $end) = (1001 + ($i - 1) * 100, 1000 + $i * 100);
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;"
);
test_index_replay("insert $i");
}
done_testing();

43
test/t/011_hnsw_vacuum.pl Normal file
View File

@@ -0,0 +1,43 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $dim = 3;
my @r = ();
for (1 .. $dim) {
my $v = int(rand(1000)) + 1;
push(@r, "i % $v");
}
my $array_sql = join(", ", @r);
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table and index
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
);
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
# Get size
my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');");
# Delete all, vacuum, and insert same data
$node->safe_psql("postgres", "DELETE FROM tst;");
$node->safe_psql("postgres", "VACUUM tst;");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
);
# Check size
my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');");
cmp_ok($new_size, "<=", $size * 1.01, "size does not increase too much");
done_testing();

View File

@@ -0,0 +1,96 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
sub test_recall
{
my ($min, $operator) = @_;
my $correct = 0;
my $total = 0;
my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
like($explain, qr/Index Scan/);
for my $i (0 .. $#queries) {
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my %actual_set = map { $_ => 1 } @actual_ids;
my @expected_ids = split("\n", $expected[$i]);
foreach (@expected_ids) {
if (exists($actual_set{$_})) {
$correct++;
}
$total++;
}
}
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
);
# Generate queries
for (1..20) {
my $r1 = rand();
my $r2 = rand();
my $r3 = rand();
push(@queries, "[$r1,$r2,$r3]");
}
# Check each index type
my @operators = ("<->", "<#>", "<=>");
foreach (@operators) {
my $operator = $_;
# Get exact results
@expected = ();
foreach (@queries) {
my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;");
push(@expected, $res);
}
# Add index
my $opclass;
if ($operator eq "<->") {
$opclass = "vector_l2_ops";
} elsif ($operator eq "<#>") {
$opclass = "vector_ip_ops";
} else {
$opclass = "vector_cosine_ops";
}
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);");
if ($operator eq "<#>") {
test_recall(0.80, $operator);
} else {
test_recall(0.99, $operator);
}
}
done_testing();

View File

@@ -0,0 +1,103 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
sub test_recall
{
my ($min, $operator) = @_;
my $correct = 0;
my $total = 0;
my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
like($explain, qr/Index Scan/);
for my $i (0 .. $#queries) {
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my %actual_set = map { $_ => 1 } @actual_ids;
my @expected_ids = split("\n", $expected[$i]);
foreach (@expected_ids) {
if (exists($actual_set{$_})) {
$correct++;
}
$total++;
}
}
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
# Generate queries
for (1..20) {
my $r1 = rand();
my $r2 = rand();
my $r3 = rand();
push(@queries, "[$r1,$r2,$r3]");
}
# Check each index type
my @operators = ("<->", "<#>", "<=>");
foreach (@operators) {
my $operator = $_;
# Add index
my $opclass;
if ($operator eq "<->") {
$opclass = "vector_l2_ops";
} elsif ($operator eq "<#>") {
$opclass = "vector_ip_ops";
} else {
$opclass = "vector_cosine_ops";
}
$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
);
# Get exact results
@expected = ();
foreach (@queries) {
my $res = $node->safe_psql("postgres", qq(
SET enable_indexscan = off;
SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;
));
push(@expected, $res);
}
if ($operator eq "<#>") {
test_recall(0.80, $operator);
} else {
test_recall(0.99, $operator);
}
$node->safe_psql("postgres", "DROP INDEX idx;");
$node->safe_psql("postgres", "TRUNCATE tst;");
}
done_testing();

View File

@@ -0,0 +1,58 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
# Ensures elements and neighbors on both same and different pages
my $dim = 1900;
my $array_sql = join(",", ('random()') x $dim);
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table and index
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 100) i;"
);
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
$node->pgbench(
"--no-vacuum --client=5 --transactions=100",
0,
[qr{actually processed}],
[qr{^$}],
"concurrent INSERTs",
{
"007_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;"
}
);
sub idx_scan
{
# Stats do not update instantaneously
# https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS
sleep(1);
$node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;");
}
my $expected = 100 + 5 * 100 * 10;
my $count = $node->safe_psql("postgres", "SELECT COUNT(*) FROM tst;");
is($count, $expected);
is(idx_scan(), 0);
$count = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = 400;
SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t;
));
is($count, 400);
is(idx_scan(), 1);
done_testing();