Added HNSW index type

This commit is contained in:
Andrew Kane
2023-07-21 16:25:39 -07:00
parent f210791846
commit f0760eee76
28 changed files with 3625 additions and 7 deletions

View File

@@ -3,7 +3,7 @@ EXTVERSION = 0.4.4
MODULE_big = vector
DATA = $(wildcard sql/*--*.sql)
OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
TESTS = $(wildcard test/sql/*.sql)
REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS))

View File

@@ -1,7 +1,7 @@
EXTENSION = vector
EXTVERSION = 0.4.4
OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
REGRESS_OPTS = --inputdir=test --load-extension=vector

View File

@@ -18,3 +18,26 @@ CREATE AGGREGATE sum(vector) (
COMBINEFUNC = vector_add,
PARALLEL = SAFE
);
-- hnsw access method: the handler is a C function in the extension library
CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;
COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
-- hnsw opclasses: OPERATOR 1 is the ORDER BY distance operator and
-- FUNCTION 1 is the distance support function
CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);
CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);
-- the cosine opclass additionally registers a norm function as FUNCTION 2
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);

View File

@@ -227,7 +227,7 @@ CREATE OPERATOR > (
RESTRICT = scalargtsel, JOIN = scalargtjoinsel
);
-- access method
-- access methods
CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
@@ -236,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler;
COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method';
-- hnsw: handler is a C function in the extension library
CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler;
COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method';
-- opclasses
CREATE OPERATOR CLASS vector_ops
@@ -267,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops
FUNCTION 2 vector_norm(vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);
-- hnsw opclasses: OPERATOR 1 is the ORDER BY distance operator and
-- FUNCTION 1 is the distance support function
CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);
CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);
-- the cosine opclass additionally registers a norm function as FUNCTION 2
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);

210
src/hnsw.c Normal file
View File

@@ -0,0 +1,210 @@
#include "postgres.h"
#include <float.h>
#include "access/amapi.h"
#include "commands/vacuum.h"
#include "hnsw.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"
#if PG_VERSION_NUM >= 120000
#include "commands/progress.h"
#endif
/* Backing variable for the hnsw.ef_search setting defined in HnswInit */
int hnsw_ef_search;
/* Reloption kind under which the hnsw index options are registered */
static relopt_kind hnsw_relopt_kind;
/*
 * Initialize index options and variables
 *
 * Allocates a reloption kind for hnsw, registers the integer index options
 * "m" and "ef_construction" under it (default, minimum, and maximum come from
 * the HNSW_* constants), and defines the user-settable integer setting
 * hnsw.ef_search backed by hnsw_ef_search.
 */
void
HnswInit(void)
{
hnsw_relopt_kind = add_reloption_kind();
/* From PostgreSQL 13 on, add_int_reloption takes an extra lock mode argument */
add_int_reloption(hnsw_relopt_kind, "m", "Number of connections",
HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of dynamic candidate list",
HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
/* NOTE(review): the range in the description text is hard-coded; keep it in sync with HNSW_MIN/MAX_EF_SEARCH */
DefineCustomIntVariable("hnsw.ef_search", "Sets the size of dynamic candidate list",
"Valid range is 10..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
}
/*
 * Map a build phase number to the name reported for it
 *
 * Returns NULL for any phase number this access method does not define.
 * Only compiled for PostgreSQL 12 and later.
 */
#if PG_VERSION_NUM >= 120000
static char *
hnswbuildphasename(int64 phasenum)
{
	if (phasenum == PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE)
		return "initializing";

	if (phasenum == PROGRESS_HNSW_PHASE_LOAD)
		return "loading tuples";

	/* Unknown phase */
	return NULL;
}
#endif
/*
 * Estimate the cost of an index scan
 *
 * A path without ORDER BY clauses is priced at DBL_MAX so it is never
 * attractive. Otherwise the figures computed by genericcostestimate are
 * passed through unchanged.
 */
static void
hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
				 Cost *indexStartupCost, Cost *indexTotalCost,
				 Selectivity *indexSelectivity, double *indexCorrelation,
				 double *indexPages)
{
	GenericCosts costs;

	/* Never use index without order */
	if (path->indexorderbys == NULL)
	{
		*indexStartupCost = DBL_MAX;
		*indexTotalCost = DBL_MAX;
		*indexSelectivity = 0;
		*indexCorrelation = 0;
		*indexPages = 0;
		return;
	}

	MemSet(&costs, 0, sizeof(costs));

#if PG_VERSION_NUM >= 120000
	genericcostestimate(root, path, loop_count, &costs);
#else
	/* Before PostgreSQL 12 the deconstructed index quals are passed explicitly */
	genericcostestimate(root, path, loop_count, deconstruct_indexquals(path), &costs);
#endif

	/* TODO Improve cost estimate */
	*indexStartupCost = costs.indexStartupCost;
	*indexTotalCost = costs.indexTotalCost;
	*indexSelectivity = costs.indexSelectivity;
	*indexCorrelation = costs.indexCorrelation;
	*indexPages = costs.numIndexPages;
}
/*
 * Parse and validate the reloptions
 *
 * Maps the "m" and "ef_construction" options onto the matching HnswOptions
 * fields and returns the resulting struct cast to bytea *.
 */
static bytea *
hnswoptions(Datum reloptions, bool validate)
{
/* Option name, option type, and destination offset within HnswOptions */
static const relopt_parse_elt tab[] = {
{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
};
#if PG_VERSION_NUM >= 130000
/* PostgreSQL 13+: a single helper parses and fills the struct */
return (bytea *) build_reloptions(reloptions, validate,
hnsw_relopt_kind,
sizeof(HnswOptions),
tab, lengthof(tab));
#else
/* Older versions: parse, allocate, and fill in separate steps */
relopt_value *options;
int numoptions;
HnswOptions *rdopts;
options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions);
rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions);
fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions,
validate, tab, lengthof(tab));
return (bytea *) rdopts;
#endif
}
/*
 * Validate catalog entries for the specified operator class
 *
 * No checks are performed: every operator class is reported as valid and
 * opclassoid is not inspected.
 */
static bool
hnswvalidate(Oid opclassoid)
{
return true;
}
/*
 * Define index handler
 *
 * Builds and returns the IndexAmRoutine describing the hnsw access method:
 * its capability flags and the callbacks implementing build, insert, vacuum,
 * cost estimation, option parsing, and scanning.
 *
 * See https://www.postgresql.org/docs/current/index-api.html
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler);
Datum
hnswhandler(PG_FUNCTION_ARGS)
{
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
/* No fixed strategy numbers; two support procs (HNSW_DISTANCE_PROC, HNSW_NORM_PROC) */
amroutine->amstrategies = 0;
amroutine->amsupport = 2;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif
/* Capability flags: the only ordering supported is ORDER BY an operator */
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false; /* can change direction mid-scan */
amroutine->amcanunique = false;
amroutine->amcanmulticol = false;
amroutine->amoptionalkey = true;
amroutine->amsearcharray = false;
amroutine->amsearchnulls = false;
amroutine->amstorage = false;
amroutine->amclusterable = false;
amroutine->ampredlocks = false;
amroutine->amcanparallel = false;
amroutine->amcaninclude = false;
#if PG_VERSION_NUM >= 130000
amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */
amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL;
#endif
amroutine->amkeytype = InvalidOid;
/* Interface functions */
amroutine->ambuild = hnswbuild;
amroutine->ambuildempty = hnswbuildempty;
amroutine->aminsert = hnswinsert;
amroutine->ambulkdelete = hnswbulkdelete;
amroutine->amvacuumcleanup = hnswvacuumcleanup;
amroutine->amcanreturn = NULL; /* tuple not included in heapsort */
amroutine->amcostestimate = hnswcostestimate;
amroutine->amoptions = hnswoptions;
amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
#if PG_VERSION_NUM >= 120000
amroutine->ambuildphasename = hnswbuildphasename;
#endif
amroutine->amvalidate = hnswvalidate;
#if PG_VERSION_NUM >= 140000
amroutine->amadjustmembers = NULL;
#endif
/* Scan callbacks: tuple-at-a-time only, no bitmap scans, no mark/restore */
amroutine->ambeginscan = hnswbeginscan;
amroutine->amrescan = hnswrescan;
amroutine->amgettuple = hnswgettuple;
amroutine->amgetbitmap = NULL;
amroutine->amendscan = hnswendscan;
amroutine->ammarkpos = NULL;
amroutine->amrestrpos = NULL;
/* Interface functions to support parallel index scans */
amroutine->amestimateparallelscan = NULL;
amroutine->aminitparallelscan = NULL;
amroutine->amparallelrescan = NULL;
PG_RETURN_POINTER(amroutine);
}

271
src/hnsw.h Normal file
View File

@@ -0,0 +1,271 @@
#ifndef HNSW_H
#define HNSW_H
#include "postgres.h"
#include "access/generic_xlog.h"
#include "access/reloptions.h"
#include "nodes/execnodes.h"
#include "port.h" /* for random() */
#include "utils/sampling.h"
#include "vector.h"
#if PG_VERSION_NUM < 110000
#error "Requires PostgreSQL 11+"
#endif
#define HNSW_MAX_DIM 2000
/* Support functions */
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
#define HNSW_PAGE_ID 0xFF85
/* Preserved page numbers */
#define HNSW_METAPAGE_BLKNO 0
#define HNSW_HEAD_BLKNO 1 /* first element page */
#define HNSW_DEFAULT_M 16
#define HNSW_MIN_M 4
#define HNSW_MAX_M 100
#define HNSW_DEFAULT_EF_CONSTRUCTION 40
#define HNSW_MIN_EF_CONSTRUCTION 10
#define HNSW_MAX_EF_CONSTRUCTION 1000
#define HNSW_DEFAULT_EF_SEARCH 40
#define HNSW_MIN_EF_SEARCH 10
#define HNSW_MAX_EF_SEARCH 1000
#define HNSW_HEAPTIDS 10
/* Build phases */
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
#define PROGRESS_HNSW_PHASE_LOAD 2
#define HNSW_ELEMENT_TUPLE_SIZE(_dim) (offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim))
#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page))
#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page))
#if PG_VERSION_NUM >= 150000
#define RandomDouble() pg_prng_double(&pg_global_prng_state)
#else
#define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE)
#endif
#if PG_VERSION_NUM < 130000
#define list_delete_last(list) list_truncate(list, list_length(list) - 1)
#define list_sort(list, cmp) list_qsort(list, cmp)
#endif
/*
 * Number of connections for a layer: 2 * m on layer 0, m on every other
 * layer. Both arguments are parenthesized so that expression arguments
 * (e.g. "a + b" or "lc & 1") expand with the intended precedence.
 */
#define GetLayerM(m, layer) ((layer) == 0 ? (m) * 2 : (m))
/* 1 / log(m); stored as the build state's "ml" and passed to HnswInitElement */
#define HnswGetMl(m) (1 / log(m))
/* Variables */
extern int hnsw_ef_search;
/* Forward declaration: HnswElementData refers to it before it is defined */
typedef struct HnswNeighborArray HnswNeighborArray;
/* In-memory representation of a graph element */
typedef struct HnswElementData
{
/* List of heap item pointers that share this element */
List *heaptids;
uint8 level;
uint8 deleted;
/* presumably one HnswNeighborArray per level — confirm against HnswInitNeighbors */
HnswNeighborArray *neighbors;
/* Location of the element tuple in the index */
BlockNumber blkno;
OffsetNumber offno;
/* Block number of the page holding this element's neighbor tuples */
BlockNumber neighborPage;
Vector *vec;
} HnswElementData;
typedef HnswElementData * HnswElement;
/* An element paired with a distance */
typedef struct HnswCandidate
{
HnswElement element;
float distance;
} HnswCandidate;
/* Counted array of candidates */
typedef struct HnswNeighborArray
{
int length;
HnswCandidate *items;
} HnswNeighborArray;
/*
 * A pending change to an existing element's neighbor list: hc.element is the
 * element whose neighbor page is rewritten, level and index select the slot
 * (see HnswGetOffsetNumber in hnswinsert.c).
 */
typedef struct HnswUpdate
{
HnswCandidate hc;
int level;
int index;
} HnswUpdate;
/* Pairing heap node wrapping a candidate */
typedef struct HnswPairingHeapNode
{
pairingheap_node ph_node;
HnswCandidate *inner;
} HnswPairingHeapNode;
/* HNSW index options */
typedef struct HnswOptions
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int m; /* number of connections */
int efConstruction; /* size of dynamic candidate list */
} HnswOptions;
/* Working state for an index build (filled in by InitBuildState) */
typedef struct HnswBuildState
{
/* Info */
Relation heap;
Relation index;
IndexInfo *indexInfo;
ForkNumber forkNum;
/* Settings */
int dimensions;
int m;
int efConstruction;
/* Statistics */
double indtuples;
double reltuples;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
/* Variables */
/* In-memory elements collected until they are flushed to pages */
List *elements;
HnswElement entryPoint;
double ml;
int maxLevel;
/* Once indtuples reaches this, the build flushes and inserts on disk */
double maxInMemoryElements;
bool flushed;
/* Scratch vector reused for normalization of each tuple */
Vector *normvec;
/* Memory */
MemoryContext tmpCtx;
} HnswBuildState;
/* Contents of the metapage (block HNSW_METAPAGE_BLKNO) */
typedef struct HnswMetaPageData
{
uint32 magicNumber;
uint32 version;
uint32 dimensions;
uint32 m;
uint32 efConstruction;
/* Location of the entry point element; invalid until one is set */
BlockNumber entryBlkno;
OffsetNumber entryOffno;
/* Element page where inserts start looking for space */
BlockNumber insertPage;
} HnswMetaPageData;
typedef HnswMetaPageData * HnswMetaPage;
/* Special area of every index page */
typedef struct HnswPageOpaqueData
{
/* Next element page in the chain, or invalid for the last one */
BlockNumber nextblkno;
uint16 unused;
uint16 page_id; /* for identification of HNSW indexes */
} HnswPageOpaqueData;
typedef HnswPageOpaqueData * HnswPageOpaque;
/* Element tuple as stored on an element page */
typedef struct HnswElementTupleData
{
/* Unused slots hold invalid item pointers */
ItemPointerData heaptids[HNSW_HEAPTIDS];
uint8 level;
uint8 deleted;
uint16 unused;
BlockNumber neighborPage;
/* Must stay last: HNSW_ELEMENT_TUPLE_SIZE sizes the tuple from its offset */
Vector vec;
} HnswElementTupleData;
typedef HnswElementTupleData * HnswElementTuple;
/* Neighbor tuple as stored on a neighbor page */
typedef struct HnswNeighborTupleData
{
/* Index location (block, offset) of the neighboring element */
ItemPointerData indextid;
uint16 unused;
float distance;
} HnswNeighborTupleData;
typedef HnswNeighborTupleData * HnswNeighborTuple;
/* Per-scan state */
typedef struct HnswScanOpaqueData
{
bool first;
Buffer buf;
/* NOTE(review): presumably the candidate list produced by the search — confirm in hnswscan.c */
List *w;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
} HnswScanOpaqueData;
typedef HnswScanOpaqueData * HnswScanOpaque;
/* Working state for vacuum */
typedef struct HnswVacuumState
{
/* Index being vacuumed, plus the stats and callback passed to hnswbulkdelete */
Relation index;
IndexBulkDeleteResult *stats;
IndexBulkDeleteCallback callback;
void *callback_state;
/* Index settings */
int m;
int efConstruction;
HTAB *deleted;
BufferAccessStrategy bas;
/* Support function and collation */
FmgrInfo *procinfo;
Oid collation;
/* Neighbor tuple buffer and its size */
HnswNeighborTuple ntup;
Size nsize;
HnswElementData highestPoint;
MemoryContext tmpCtx;
} HnswVacuumState;
/* Methods */
/* Index option accessors */
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
/* Support function lookup and normalization */
FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
/* Buffer and page helpers */
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
/* Registers reloptions and the hnsw.ef_search setting (see hnsw.c) */
void HnswInit(void);
/* Graph search and element handling */
List *SearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage);
HnswElement GetEntryPoint(Relation index);
HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel);
void HnswFreeElement(HnswElement element);
/* The build passes NULL for index and updates; the return value is treated as a duplicate element when non-NULL */
HnswElement HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming);
HnswCandidate *EntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadvec);
/* Build callers pass InvalidBlockNumber as insertPage when only the entry point changes */
void UpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum);
void AddNeighborsToPage(Relation index, Page page, HnswElement e, HnswNeighborTuple neighbor, Size neighborsz, int m);
void HnswAddHeapTid(HnswElement element, ItemPointer heaptid);
void HnswInitNeighbors(HnswElement element, int m);
/* On-disk insert; also used by the build once the graph has been flushed */
bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel);
void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec);
/* Index access methods */
IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo);
void hnswbuildempty(Relation index);
/* The indexUnchanged parameter exists only from PostgreSQL 14 on */
bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 140000
,bool indexUnchanged
#endif
,IndexInfo *indexInfo
);
IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys);
void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
bool hnswgettuple(IndexScanDesc scan, ScanDirection dir);
void hnswendscan(IndexScanDesc scan);
/*
 * Maximum element level for a given m.
 *
 * Computes how many neighbor tuples (aligned tuple plus line pointer) fit in
 * the usable space of one page, divides by m, subtracts 2, and caps the
 * result at 255 so it fits in the uint8 level fields.
 * NOTE(review): the "- 2" presumably reserves room because an element at
 * level L needs (L + 2) * m slots (layer 0 holds 2 * m) — confirm against
 * AddNeighborsToPage.
 *
 * The argument is parenthesized so expression arguments divide correctly.
 */
#define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData))) / (MAXALIGN(sizeof(HnswNeighborTupleData)) + sizeof(ItemIdData)) / (m)) - 2, 255)
#endif

484
src/hnswbuild.c Normal file
View File

@@ -0,0 +1,484 @@
#include "postgres.h"
#include <math.h>
#include "catalog/index.h"
#include "hnsw.h"
#include "miscadmin.h"
#include "lib/pairingheap.h"
#include "nodes/pg_list.h"
#include "storage/bufmgr.h"
#include "utils/memutils.h"
#if PG_VERSION_NUM >= 140000
#include "utils/backend_progress.h"
#elif PG_VERSION_NUM >= 120000
#include "pgstat.h"
#endif
#if PG_VERSION_NUM >= 120000
#include "access/tableam.h"
#include "commands/progress.h"
#else
#define PROGRESS_CREATEIDX_TUPLES_DONE 0
#endif
/*
 * Second parameter of the build callback: an ItemPointer named "tid" from
 * PostgreSQL 13 on, a HeapTuple named "hup" before that.
 */
#if PG_VERSION_NUM >= 130000
#define CALLBACK_ITEM_POINTER ItemPointer tid
#else
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
/*
 * Progress reporting: forwards to pgstat_progress_update_param from
 * PostgreSQL 12 on; before that it only evaluates val (so side effects such
 * as ++counter still happen) and reports nothing.
 */
#if PG_VERSION_NUM >= 120000
#define UpdateProgress(index, val) pgstat_progress_update_param(index, val)
#else
#define UpdateProgress(index, val) ((void)val)
#endif
/*
 * Create the metapage
 *
 * Allocates a new buffer in the build's fork, fills in the identification
 * fields and index settings, leaves the entry point and insert page invalid,
 * and commits the page.
 */
static void
CreateMetaPage(HnswBuildState * buildstate)
{
	Buffer		metabuf;
	Page		metapage;
	GenericXLogState *xlogState;
	HnswMetaPage meta;

	metabuf = HnswNewBuffer(buildstate->index, buildstate->forkNum);
	HnswInitRegisterPage(buildstate->index, &metabuf, &metapage, &xlogState);

	/* Identification and settings */
	meta = HnswPageGetMeta(metapage);
	meta->magicNumber = HNSW_MAGIC_NUMBER;
	meta->version = HNSW_VERSION;
	meta->dimensions = buildstate->dimensions;
	meta->m = buildstate->m;
	meta->efConstruction = buildstate->efConstruction;

	/* Nothing has been written yet, so these start out invalid */
	meta->entryBlkno = InvalidBlockNumber;
	meta->entryOffno = InvalidOffsetNumber;
	meta->insertPage = InvalidBlockNumber;

	/* Mark the end of the metadata as the end of used space on the page */
	((PageHeader) metapage)->pd_lower =
		((char *) meta + sizeof(HnswMetaPageData)) - (char *) metapage;

	HnswCommitBuffer(metabuf, xlogState);
}
/*
 * Create element pages
 *
 * Writes one element tuple per in-memory element, chaining pages through
 * nextblkno as they fill up. Each element is assigned its neighbor page
 * number up front, on the assumption that the neighbor pages will be
 * allocated sequentially right after the element pages; CreateNeighborPages
 * verifies this. Records the entry point and the last element page (as the
 * insert page) in the metapage.
 */
static void
CreateElementPages(HnswBuildState * buildstate)
{
Relation index = buildstate->index;
ForkNumber forkNum = buildstate->forkNum;
int dimensions = buildstate->dimensions;
Size elementsz;
HnswElementTuple element;
int elementsPerPage;
BlockNumber neighborPage;
BlockNumber insertPage;
Buffer buf;
Page page;
GenericXLogState *state;
ListCell *lc;
/* Allocate once */
elementsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dimensions));
element = palloc0(elementsz);
/* Calculate starting neighbor page */
/* Element pages start at HNSW_HEAD_BLKNO; neighbor pages follow the last of them */
elementsPerPage = (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData))) / (elementsz + sizeof(ItemIdData));
neighborPage = HNSW_HEAD_BLKNO + (int) ceil(list_length(buildstate->elements) / (double) elementsPerPage);
/* Prepare first page */
buf = HnswNewBuffer(index, forkNum);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(buf, page);
foreach(lc, buildstate->elements)
{
HnswElement e = lfirst(lc);
/* Calculate neighbor page */
/* Will be rechecked later */
e->neighborPage = neighborPage++;
/* Set item data */
/* Copy the element's heap TIDs; remaining slots are marked invalid */
for (int i = 0; i < HNSW_HEAPTIDS; i++)
{
if (i < list_length(e->heaptids))
element->heaptids[i] = *((ItemPointer) list_nth(e->heaptids, i));
else
ItemPointerSetInvalid(&element->heaptids[i]);
}
element->level = e->level;
element->deleted = 0;
element->neighborPage = e->neighborPage;
memcpy(&element->vec, e->vec, VECTOR_SIZE(dimensions));
/* Ensure free space */
if (PageGetFreeSpace(page) < elementsz)
{
/* Add a new page */
Buffer newbuf = HnswNewBuffer(index, forkNum);
/* Update previous page */
HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(newbuf);
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
/* Can take a while, so ensure we can interrupt */
/* Needs to be called when no buffer locks are held */
CHECK_FOR_INTERRUPTS();
/* Prepare new page */
buf = newbuf;
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(buf, page);
}
/* Add the item */
/* Remember where the element landed so neighbors can refer to it */
e->blkno = BufferGetBlockNumber(buf);
e->offno = PageAddItem(page, (Item) element, elementsz, InvalidOffsetNumber, false, false);
if (e->offno == InvalidOffsetNumber)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Record the entry point now that its location is known */
if (e == buildstate->entryPoint)
UpdateMetaPage(index, true, e, InvalidBlockNumber, forkNum);
}
/* The last element page becomes the insert page */
insertPage = BufferGetBlockNumber(buf);
/* Commit */
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
UpdateMetaPage(index, false, NULL, insertPage, forkNum);
}
/*
 * Create neighbor pages
 *
 * Allocates one new page per in-memory element, in list order, and fills it
 * with that element's neighbor tuples. Each newly allocated block must match
 * the neighbor page number assigned to the element earlier; a mismatch
 * raises an error.
 */
static void
CreateNeighborPages(HnswBuildState * buildstate)
{
	Relation	index = buildstate->index;
	ForkNumber	forkNum = buildstate->forkNum;
	Size		neighborsz;
	HnswNeighborTuple neighbor;
	ListCell   *lc;

	/* Allocate once */
	neighborsz = MAXALIGN(sizeof(HnswNeighborTupleData));
	neighbor = palloc0(neighborsz);

	foreach(lc, buildstate->elements)
	{
		HnswElement e = lfirst(lc);
		Buffer		buf;
		Page		page;
		GenericXLogState *state;

		/* Can take a while, so ensure we can interrupt */
		/* Needs to be called when no buffer locks are held */
		CHECK_FOR_INTERRUPTS();

		buf = HnswNewBuffer(index, forkNum);

		/* Check block number (block numbers are unsigned, so print with %u) */
		if (BufferGetBlockNumber(buf) != e->neighborPage)
			elog(ERROR, "expected neighbor page %u, got %u", e->neighborPage, BufferGetBlockNumber(buf));

		/* Prepare page */
		state = GenericXLogStart(index);
		page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE);
		HnswInitPage(buf, page);

		AddNeighborsToPage(index, page, e, neighbor, neighborsz, buildstate->m);

		/* Commit */
		MarkBufferDirty(buf);
		GenericXLogFinish(state);
		UnlockReleaseBuffer(buf);
	}
}
/*
 * Release the in-memory graph
 *
 * Frees every element in the build state's element list, then the list
 * itself.
 */
static void
FreeElements(HnswBuildState * buildstate)
{
	List	   *elements = buildstate->elements;
	ListCell   *cell;

	foreach(cell, elements)
	{
		HnswElement element = lfirst(cell);

		HnswFreeElement(element);
	}

	list_free(elements);
}
/*
 * Flush pages
 *
 * Writes the in-memory graph to the index: the metapage first, then the
 * element pages, then the neighbor pages. Afterwards marks the build state
 * as flushed and frees the in-memory elements.
 */
static void
FlushPages(HnswBuildState * buildstate)
{
CreateMetaPage(buildstate);
/* Element pages must come first: they assign the neighbor page numbers that CreateNeighborPages checks */
CreateElementPages(buildstate);
CreateNeighborPages(buildstate);
buildstate->flushed = true;
FreeElements(buildstate);
}
/*
 * Insert one tuple into the in-memory graph
 *
 * Detoasts (and, when a norm function exists, normalizes) the value, copies
 * it into the element, and inserts the element into the graph. Returns false
 * when normalization fails or when the insert reports a duplicate, which is
 * then handed back through *dup; returns true when the element was added.
 */
static bool
InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup)
{
	HnswElement entryPoint = buildstate->entryPoint;
	Oid			collation = buildstate->collation;
	Datum		value;
	bool		added;

	/* Detoast once for all calls */
	value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));

	/* Normalize if needed */
	if (buildstate->normprocinfo != NULL &&
		!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec))
		return false;

	/* Copy value to element so accessible outside of memory context */
	memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions));

	/* Insert element in graph */
	*dup = HnswInsertElement(element, entryPoint, NULL, buildstate->procinfo, collation,
							 buildstate->m, buildstate->efConstruction, NULL, false);
	added = (*dup == NULL);

	/* A new element becomes the entry point if it is the first or the highest */
	if (added && (entryPoint == NULL || element->level > entryPoint->level))
		buildstate->entryPoint = element;

	UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);

	return added;
}
/*
 * Callback for table_index_build_scan
 *
 * Called once per heap tuple. Null values are skipped. While the tuple count
 * is below maxInMemoryElements the tuple is added to the in-memory graph;
 * once the limit is reached the graph is flushed to pages (once, with a
 * NOTICE) and this and all later tuples go through the on-disk insert path.
 */
static void
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
HnswBuildState *buildstate = (HnswBuildState *) state;
MemoryContext oldCtx;
HnswElement element;
HnswElement dup = NULL;
bool inserted;
/* Before PostgreSQL 13 the callback receives the heap tuple, not its TID */
#if PG_VERSION_NUM < 130000
ItemPointer tid = &hup->t_self;
#endif
/* Skip nulls */
if (isnull[0])
return;
/* On-disk path once the in-memory limit has been reached */
if (buildstate->indtuples >= buildstate->maxInMemoryElements)
{
if (!buildstate->flushed)
{
/* TODO Improve message */
ereport(NOTICE,
(errmsg("hnsw graph no longer fits into maintenance_work_mem"),
errdetail("Building will take significantly more time."),
errhint("Increase maintenance_work_mem to speed up future builds.")));
FlushPages(buildstate);
}
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap))
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(buildstate->tmpCtx);
return;
}
/* Allocate necessary memory outside of memory context */
element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel);
element->vec = palloc(VECTOR_SIZE(buildstate->dimensions));
/* Use memory context since detoast can allocate */
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
/* Insert tuple */
inserted = InsertTuple(index, values, element, buildstate, &dup);
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(buildstate->tmpCtx);
/* Add outside memory context */
/* A duplicate only gains this tuple's heap TID; the new element is freed below */
if (dup != NULL)
HnswAddHeapTid(dup, tid);
/* Add to buildstate or free */
if (inserted)
buildstate->elements = lappend(buildstate->elements, element);
else
HnswFreeElement(element);
}
/*
* Get the max number of elements that fit into maintenance_work_mem
*/
static double
HnswGetMaxInMemoryElements(int m, double ml, int dimensions)
{
Size elementSize = sizeof(HnswElementData);
double avgLevel = -log(0.5) * ml;
elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2));
elementSize += sizeof(ItemPointerData);
elementSize += VECTOR_SIZE(dimensions);
return (maintenance_work_mem * 1024L) / elementSize;
}
/*
 * Initialize the build state
 *
 * Records the relations and fork, reads the index options and the column's
 * dimension count, looks up the support functions, derives the level
 * parameters and the in-memory element limit, and creates the temporary
 * memory context. Raises an error if the column has no dimensions or more
 * than HNSW_MAX_DIM.
 */
static void
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
{
buildstate->heap = heap;
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->forkNum = forkNum;
buildstate->m = HnswGetM(index);
buildstate->efConstruction = HnswGetEfConstruction(index);
/* The dimension count is taken from the type modifier of the indexed column */
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions");
if (buildstate->dimensions > HNSW_MAX_DIM)
elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM);
buildstate->reltuples = 0;
buildstate->indtuples = 0;
/* Get support functions */
/* The norm function is optional; NULL is treated as "no normalization" by InsertTuple */
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
buildstate->collation = index->rd_indcollation[0];
buildstate->elements = NIL;
buildstate->entryPoint = NULL;
/* These depend on m, which was set above */
buildstate->ml = HnswGetMl(buildstate->m);
buildstate->maxLevel = HnswGetMaxLevel(buildstate->m);
buildstate->maxInMemoryElements = HnswGetMaxInMemoryElements(buildstate->m, buildstate->ml, buildstate->dimensions);
buildstate->flushed = false;
/* Reuse for each tuple */
buildstate->normvec = InitVector(buildstate->dimensions);
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw build temporary context",
ALLOCSET_DEFAULT_SIZES);
}
/*
 * Free resources
 *
 * Deletes the temporary memory context created by InitBuildState. Nothing
 * else in the build state is released here.
 */
static void
FreeBuildState(HnswBuildState * buildstate)
{
MemoryContextDelete(buildstate->tmpCtx);
}
/*
 * Build graph
 *
 * Reports the "loading tuples" phase, then scans the heap, feeding every
 * tuple to BuildCallback. The scan's return value is stored as reltuples.
 * The forkNum parameter is not used in this function.
 */
static void
BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum)
{
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD);
/* The heap scan entry point was renamed and gained a parameter in PostgreSQL 12 */
#if PG_VERSION_NUM >= 120000
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, BuildCallback, (void *) buildstate, NULL);
#else
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate, NULL);
#endif
}
/*
 * Build the index
 *
 * Initializes the build state, scans the heap when one is given, writes the
 * pages unless the graph was already flushed during the scan, and frees the
 * build resources. The caller supplies buildstate so it can read the tuple
 * counts afterwards.
 */
static void
BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
HnswBuildState * buildstate, ForkNumber forkNum)
{
InitBuildState(buildstate, heap, index, indexInfo, forkNum);
/* heap is NULL when called from hnswbuildempty; then only empty pages are written */
if (buildstate->heap != NULL)
BuildGraph(buildstate, forkNum);
/* BuildCallback flushes on its own when the in-memory limit is reached */
if (!buildstate->flushed)
FlushPages(buildstate);
FreeBuildState(buildstate);
}
/*
 * Build the index for a logged table
 *
 * Runs the build against the main fork and returns the heap and index tuple
 * counts collected in the build state.
 */
IndexBuildResult *
hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo)
{
	HnswBuildState buildstate;
	IndexBuildResult *result;

	BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM);

	/* Hand the counts back to the caller */
	result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult));
	result->index_tuples = buildstate.indtuples;
	result->heap_tuples = buildstate.reltuples;

	return result;
}
/*
 * Build the index for an unlogged table
 *
 * Runs the build against the init fork with no heap, so no tuples are
 * scanned and only the pages for an empty index are written.
 */
void
hnswbuildempty(Relation index)
{
IndexInfo *indexInfo = BuildIndexInfo(index);
HnswBuildState buildstate;
BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM);
}

416
src/hnswinsert.c Normal file
View File

@@ -0,0 +1,416 @@
#include "postgres.h"
#include <math.h>
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "storage/lmgr.h"
#include "utils/memutils.h"
/*
 * Read the current insert page from the metapage
 *
 * Takes a share lock on the metapage only for the duration of the read.
 */
static BlockNumber
GetInsertPage(Relation index)
{
	Buffer		metabuf;
	BlockNumber result;

	metabuf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
	LockBuffer(metabuf, BUFFER_LOCK_SHARE);

	result = HnswPageGetMeta(BufferGetPage(metabuf))->insertPage;

	UnlockReleaseBuffer(metabuf);

	return result;
}
/*
 * Look for a reusable slot on an element page
 *
 * Scans the page's items in offset order for the first element tuple marked
 * deleted. On success stores its offset in *freeOffno and its neighbor page
 * in *neighborPage and returns true; otherwise leaves both untouched and
 * returns false.
 */
static bool
HnswFreeOffset(Page page, OffsetNumber *freeOffno, BlockNumber *neighborPage)
{
	OffsetNumber last = PageGetMaxOffsetNumber(page);
	OffsetNumber cur = FirstOffsetNumber;

	while (cur <= last)
	{
		HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, cur));

		if (etup->deleted)
		{
			*freeOffno = cur;
			*neighborPage = etup->neighborPage;
			return true;
		}

		cur = OffsetNumberNext(cur);
	}

	return false;
}
/*
 * Add to element and neighbor pages
 *
 * Writes a new element tuple and its neighbor page. Starting at the insert
 * page recorded in the metapage, walks the element page chain until it finds
 * a page with a deleted slot to reuse or enough free space; if the chain
 * ends first, it extends the relation with a new element page. The neighbor
 * page is the deleted element's old one when a slot is reused, otherwise a
 * newly allocated page. Sets e->blkno, e->offno, and e->neighborPage, and
 * updates the metapage if the insert page moved.
 */
static void
WriteNewElementPages(Relation index, HnswElement e, int m)
{
Buffer buf;
Page page;
GenericXLogState *state;
Size esize;
HnswElementTuple etup;
BlockNumber insertPage = GetInsertPage(index);
BlockNumber originalInsertPage = insertPage;
int dimensions = e->vec->dim;
Size nsize = MAXALIGN(sizeof(HnswNeighborTupleData));
HnswNeighborTuple ntup = palloc0(nsize);
Buffer nbuf;
Page npage;
/* Remain invalid unless HnswFreeOffset finds a deleted slot */
OffsetNumber freeOffno = InvalidOffsetNumber;
BlockNumber neighborPage = InvalidBlockNumber;
/* Get tuple size */
esize = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dimensions));
/* Prepare tuple */
/* Only the first heap TID of the element is stored; the other slots are invalid */
etup = palloc0(esize);
etup->heaptids[0] = *((ItemPointer) linitial(e->heaptids));
for (int i = 1; i < HNSW_HEAPTIDS; i++)
ItemPointerSetInvalid(&etup->heaptids[i]);
etup->level = e->level;
etup->deleted = 0;
memcpy(&etup->vec, e->vec, VECTOR_SIZE(dimensions));
/* Find a page to insert the item */
/* Every exit from this loop leaves buf exclusively locked and registered in state */
for (;;)
{
buf = ReadBuffer(index, insertPage);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
if (HnswFreeOffset(page, &freeOffno, &neighborPage) || PageGetFreeSpace(page) >= esize)
break;
insertPage = HnswPageGetOpaque(page)->nextblkno;
if (BlockNumberIsValid(insertPage))
{
/* Move to next page */
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
else
{
Buffer newbuf;
Page newpage;
/*
* From ReadBufferExtended: Caller is responsible for ensuring
* that only one backend tries to extend a relation at the same
* time!
*/
LockRelationForExtension(index, ExclusiveLock);
/* Add a new page */
newbuf = HnswNewBuffer(index, MAIN_FORKNUM);
/* Unlock extend relation lock as early as possible */
UnlockRelationForExtension(index, ExclusiveLock);
/* Init new page */
newpage = GenericXLogRegisterBuffer(state, newbuf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(newbuf, newpage);
/* Update insert page */
insertPage = BufferGetBlockNumber(newbuf);
/* Update previous buffer */
HnswPageGetOpaque(page)->nextblkno = insertPage;
/* Commit */
MarkBufferDirty(newbuf);
MarkBufferDirty(buf);
GenericXLogFinish(state);
/* Unlock previous buffer */
UnlockReleaseBuffer(buf);
/* Prepare new buffer */
/* newbuf is not unlocked here, so it carries over as buf for the insert below */
state = GenericXLogStart(index);
buf = newbuf;
page = GenericXLogRegisterBuffer(state, buf, 0);
break;
}
}
if (OffsetNumberIsValid(freeOffno))
{
/* Reuse existing page */
nbuf = ReadBuffer(index, neighborPage);
LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE);
}
else
{
/* Add new page */
LockRelationForExtension(index, ExclusiveLock);
nbuf = HnswNewBuffer(index, MAIN_FORKNUM);
UnlockRelationForExtension(index, ExclusiveLock);
}
/* The neighbor page joins the same WAL record as the element page */
npage = GenericXLogRegisterBuffer(state, nbuf, GENERIC_XLOG_FULL_IMAGE);
/* Overwrites existing page via InitPage */
HnswInitPage(nbuf, npage);
/* Update neighbors */
AddNeighborsToPage(index, npage, e, ntup, nsize, m);
e->blkno = BufferGetBlockNumber(buf);
e->neighborPage = BufferGetBlockNumber(nbuf);
/* Set neighbor page for element */
etup->neighborPage = e->neighborPage;
/* Add to next offset */
if (OffsetNumberIsValid(freeOffno))
{
/* Overwrite the deleted element tuple in place */
e->offno = freeOffno;
if (!PageIndexTupleOverwrite(page, freeOffno, (Item) etup, esize))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
}
else
{
e->offno = PageAddItem(page, (Item) etup, esize, InvalidOffsetNumber, false, false);
if (e->offno == InvalidOffsetNumber)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
}
/* Commit */
MarkBufferDirty(buf);
MarkBufferDirty(nbuf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
UnlockReleaseBuffer(nbuf);
/* Update the insert page */
if (insertPage != originalInsertPage)
UpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM);
}
/*
 * Calculate offset number for update
 *
 * Maps an update to its slot on the target element's neighbor page: m slots
 * for each level between the element's own level and the update's level,
 * plus the update's index, counted from the first offset.
 */
static OffsetNumber
HnswGetOffsetNumber(HnswUpdate * update, int m)
{
	int			levelsAbove = update->hc.element->level - update->level;

	return FirstOffsetNumber + levelsAbove * m + update->index;
}
/*
 * Persist the back-links collected while inserting an element
 *
 * For every entry in updates, the neighbor page of the affected element is
 * locked exclusively and the slot computed by HnswGetOffsetNumber is
 * overwritten with a tuple pointing at e. Each update is logged as its own
 * generic WAL record.
 */
static void
UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates)
{
	ListCell   *cell;
	Size		tupleSize = MAXALIGN(sizeof(HnswNeighborTupleData));
	HnswNeighborTuple ntup = palloc0(tupleSize);

	/*
	 * Several updates for the same element could share one page write, but
	 * that should be rare, so each one is handled on its own for now.
	 */
	foreach(cell, updates)
	{
		HnswUpdate *update = lfirst(cell);
		Buffer		nbuf;
		Page		npage;
		GenericXLogState *xlog;
		OffsetNumber slot;

		/* Lock and register the neighbor page */
		nbuf = ReadBuffer(index, update->hc.element->neighborPage);
		LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE);
		xlog = GenericXLogStart(index);
		npage = GenericXLogRegisterBuffer(xlog, nbuf, 0);

		slot = HnswGetOffsetNumber(update, m);

		/* Be defensive: skip slots that do not exist on the page */
		if (slot > PageGetMaxOffsetNumber(npage))
		{
			GenericXLogAbort(xlog);
			UnlockReleaseBuffer(nbuf);
			continue;
		}

		/* Fill in the tuple that points back at the new element */
		ItemPointerSet(&ntup->indextid, e->blkno, e->offno);
		ntup->distance = update->hc.distance;

		/* Replace the existing slot contents */
		if (!PageIndexTupleOverwrite(npage, slot, (Item) ntup, tupleSize))
			elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

		/* Commit and release */
		MarkBufferDirty(nbuf);
		GenericXLogFinish(xlog);
		UnlockReleaseBuffer(nbuf);
	}
}
/*
 * Append the new element's heap tid to an existing element tuple
 *
 * Returns true when the tid was stored on dup's tuple. Returns false when
 * the tuple has no valid heap tid at all or when every heap tid slot is
 * already in use; the page is left untouched in that case.
 */
static bool
HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup)
{
	Size		tupleSize = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim));
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	HnswElementTuple etup;
	int			slot = 0;

	/* Lock and register the page holding the duplicate */
	buf = ReadBuffer(index, dup->blkno);
	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
	state = GenericXLogStart(index);
	page = GenericXLogRegisterBuffer(state, buf, 0);

	etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno));

	/* Advance to the first unused heap tid slot */
	while (slot < HNSW_HEAPTIDS && ItemPointerIsValid(&etup->heaptids[slot]))
		slot++;

	/*
	 * No valid tid means the element is being deleted; no free slot means
	 * another backend got there first.
	 */
	if (slot == 0 || slot == HNSW_HEAPTIDS)
	{
		GenericXLogAbort(state);
		UnlockReleaseBuffer(buf);
		return false;
	}

	/* Store the first heap tid of the new element */
	etup->heaptids[slot] = *((ItemPointer) linitial(element->heaptids));

	/* Write the tuple back in place */
	if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, tupleSize))
		elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

	/* Commit and release */
	MarkBufferDirty(buf);
	GenericXLogFinish(state);
	UnlockReleaseBuffer(buf);

	return true;
}
/*
 * Persist the result of an insert
 *
 * When a duplicate was found, first try to attach the heap tid to it. If
 * there is no duplicate, or attaching fails, write the element and its
 * neighbor page, apply the neighbor updates, and make the element the
 * entry point when the index had none or the element sits on a higher
 * level than the current one.
 */
static void
WriteElement(Relation index, HnswElement element, int m, List *updates, HnswElement dup, HnswElement entryPoint)
{
	bool		becomesEntryPoint;

	/* Cheap path: reuse the existing tuple */
	if (dup != NULL && HnswAddDuplicate(index, element, dup))
		return;

	/* Full path: new element tuple, new neighbor page, back-links */
	WriteNewElementPages(index, element, m);
	UpdateNeighborPages(index, element, m, updates);

	/* Promote to entry point if needed */
	becomesEntryPoint = (entryPoint == NULL) || (element->level > entryPoint->level);
	if (becomesEntryPoint)
		UpdateMetaPage(index, true, element, InvalidBlockNumber, MAIN_FORKNUM);
}
/*
 * Insert one vector into the graph and write it to the index
 *
 * Returns false, without inserting, when the operator class provides a
 * norm function and HnswNormValue rejects the value. Returns true once the
 * element has been written.
 */
bool
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel)
{
	int			m = HnswGetM(index);
	int			efConstruction = HnswGetEfConstruction(index);
	double		ml = HnswGetMl(m);
	FmgrInfo   *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
	Oid			collation = index->rd_indcollation[0];
	FmgrInfo   *normprocinfo;
	List	   *updates = NIL;
	HnswElement element;
	HnswElement entryPoint;
	HnswElement dup;
	Datum		value;

	/* Detoast up front so every later call sees a plain value */
	value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));

	/* Normalize when the operator class asks for it */
	normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
	if (normprocinfo != NULL && !HnswNormValue(normprocinfo, collation, &value, NULL))
		return false;

	/* Build the in-memory element */
	element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m));
	element->vec = DatumGetVector(value);

	/* Read the current entry point */
	entryPoint = GetEntryPoint(index);

	/* Find neighbors; a non-NULL result is an existing duplicate */
	dup = HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, &updates, false);

	/* Persist */
	WriteElement(index, element, m, updates, dup, entryPoint);

	return true;
}
/*
 * Index insert entry point
 *
 * Null values are not indexed. The insert itself runs inside a temporary
 * memory context that is deleted before returning. The return value is
 * always false.
 */
bool
hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
		   Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 140000
		   ,bool indexUnchanged
#endif
		   ,IndexInfo *indexInfo
)
{
	MemoryContext callerCtx;
	MemoryContext scratchCtx;

	if (!isnull[0])
	{
		/* Scratch context for everything allocated during the insert */
		scratchCtx = AllocSetContextCreate(CurrentMemoryContext,
										   "Hnsw insert temporary context",
										   ALLOCSET_DEFAULT_SIZES);
		callerCtx = MemoryContextSwitchTo(scratchCtx);

		HnswInsertTuple(index, values, isnull, heap_tid, heap);

		/* Drop the scratch context again */
		MemoryContextSwitchTo(callerCtx);
		MemoryContextDelete(scratchCtx);
	}

	return false;
}

191
src/hnswscan.c Normal file
View File

@@ -0,0 +1,191 @@
#include "postgres.h"
#include "access/relscan.h"
#include "hnsw.h"
#include "pgstat.h"
#include "storage/bufmgr.h"
/*
 * Algorithm 5 from paper
 *
 * Starts at the index entry point, searches each layer above 0 with a
 * candidate list size of 1, then searches layer 0 with hnsw_ef_search and
 * stores the result in so->w. Leaves so->w untouched when the index has no
 * entry point.
 */
static void
GetScanItems(IndexScanDesc scan, Datum q)
{
	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
	Relation	index = scan->indexRelation;
	HnswElement entryPoint = GetEntryPoint(index);
	List	   *candidates = NIL;
	int			layer;

	/* Empty index */
	if (entryPoint == NULL)
		return;

	/* TODO Use memory context */
	candidates = lappend(candidates, EntryCandidate(entryPoint, q, index, so->procinfo, so->collation, false));

	/* Walk down through the upper layers */
	layer = entryPoint->level;
	while (layer >= 1)
	{
		candidates = SearchLayer(q, candidates, 1, layer, index, so->procinfo, so->collation, false, NULL);
		layer--;
	}

	/* TODO Return all visited elements at level 0, not just ef search */
	so->w = SearchLayer(q, candidates, hnsw_ef_search, 0, index, so->procinfo, so->collation, false, NULL);
}
/*
 * Set up a scan descriptor and its HNSW-specific state
 *
 * The opaque state starts with no pinned buffer, no results, and the
 * "first call" flag set; support function lookups are cached in it.
 */
IndexScanDesc
hnswbeginscan(Relation index, int nkeys, int norderbys)
{
	IndexScanDesc scan = RelationGetIndexScan(index, nkeys, norderbys);
	HnswScanOpaque so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData));

	/* Scan position */
	so->buf = InvalidBuffer;
	so->first = true;
	so->w = NIL;

	/* Distance function, optional norm function, and collation */
	so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
	so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
	so->collation = index->rd_indcollation[0];

	scan->opaque = so;

	return scan;
}
/*
 * Reset the scan so the next fetch starts over
 *
 * Any previous result list is freed, and the supplied scan keys and
 * order-by keys are copied into the scan descriptor.
 */
void
hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
{
	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;

	/* Forget earlier results */
	list_free(so->w);
	so->w = NIL;
	so->first = true;

	/* Copy the new keys, when present */
	if (keys && scan->numberOfKeys > 0)
	{
		memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
	}

	if (orderbys && scan->numberOfOrderBys > 0)
	{
		memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
	}
}
/*
 * Return the next heap tid of the scan
 *
 * On the first call the graph search is run and its result kept in so->w.
 * Each call then consumes one heap tid from the last candidate in that
 * list. Returns false when the order-by argument is null, when
 * normalization fails, or when the result list is exhausted.
 */
bool
hnswgettuple(IndexScanDesc scan, ScanDirection dir)
{
	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;

	/*
	 * Index can be used to scan backward, but Postgres doesn't support
	 * backward scan on operators
	 */
	Assert(ScanDirectionIsForward(dir));

	if (so->first)
	{
		Datum		query;

		/* Count index scan for stats */
		pgstat_count_index_scan(scan->indexRelation);

		/* An ordering operator is required */
		if (scan->orderByData == NULL)
			elog(ERROR, "cannot scan hnsw index without order");

		/* A null query value matches nothing */
		if (scan->orderByData->sk_flags & SK_ISNULL)
			return false;

		query = scan->orderByData->sk_argument;

		/* Value should not be compressed or toasted */
		Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(query)));
		Assert(!VARATT_IS_EXTENDED(DatumGetPointer(query)));

		/* Nothing matches when the value cannot be normalized */
		if (so->normprocinfo != NULL && !HnswNormValue(so->normprocinfo, so->collation, &query, NULL))
			return false;

		GetScanItems(scan, query);
		so->first = false;

		/* Normalization may have allocated a new value; free it */
		if (query != scan->orderByData->sk_argument)
			pfree(DatumGetPointer(query));
	}

	for (;;)
	{
		HnswCandidate *nearest;
		ItemPointer heaptid;
		BlockNumber elementBlkno;

		/* Out of results */
		if (list_length(so->w) == 0)
			break;

		nearest = llast(so->w);

		/* Drop candidates whose heap tids have all been consumed */
		if (list_length(nearest->element->heaptids) == 0)
		{
			so->w = list_delete_last(so->w);
			continue;
		}

		/* Take the last remaining heap tid of this candidate */
		heaptid = llast(nearest->element->heaptids);
		elementBlkno = nearest->element->blkno;
		nearest->element->heaptids = list_delete_last(nearest->element->heaptids);

#if PG_VERSION_NUM >= 120000
		scan->xs_heaptid = *heaptid;
#else
		scan->xs_ctup.t_self = *heaptid;
#endif

		/* Swap the pin over to the page of the item being returned */
		if (BufferIsValid(so->buf))
			ReleaseBuffer(so->buf);

		/*
		 * An index scan must maintain a pin on the index page holding the
		 * item last returned by amgettuple
		 *
		 * https://www.postgresql.org/docs/current/index-locking.html
		 */
		so->buf = ReadBuffer(scan->indexRelation, elementBlkno);

		scan->xs_recheckorderby = false;
		return true;
	}

	return false;
}
/*
 * Tear down a scan
 *
 * Drops the pin held for the last returned item, if any, and frees the
 * result list and the opaque state.
 */
void
hnswendscan(IndexScanDesc scan)
{
	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
	Buffer		pinned = so->buf;

	/* Release pin */
	if (BufferIsValid(pinned))
		ReleaseBuffer(pinned);

	list_free(so->w);
	pfree(so);

	scan->opaque = NULL;
}

927
src/hnswutils.c Normal file
View File

@@ -0,0 +1,927 @@
#include "postgres.h"
#include <math.h>
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "vector.h"
/*
 * Get the number of connections configured for the index
 *
 * Falls back to HNSW_DEFAULT_M when the relation carries no options.
 */
int
HnswGetM(Relation index)
{
	HnswOptions *opts = (HnswOptions *) index->rd_options;

	if (!opts)
		return HNSW_DEFAULT_M;

	return opts->m;
}
/*
 * Get the size of the dynamic candidate list used while building
 *
 * Falls back to HNSW_DEFAULT_EF_CONSTRUCTION when the relation carries no
 * options.
 */
int
HnswGetEfConstruction(Relation index)
{
	HnswOptions *opts = (HnswOptions *) index->rd_options;

	if (!opts)
		return HNSW_DEFAULT_EF_CONSTRUCTION;

	return opts->efConstruction;
}
/*
 * Look up an optional support function
 *
 * Returns NULL when the operator class does not define procnum for the
 * first index column, otherwise the function's lookup info.
 */
FmgrInfo *
HnswOptionalProcInfo(Relation rel, uint16 procnum)
{
	if (OidIsValid(index_getprocid(rel, 1, procnum)))
		return index_getprocinfo(rel, 1, procnum);

	return NULL;
}
/*
 * Divide a vector by its norm
 *
 * The norm is obtained from procinfo. When it is greater than zero, the
 * scaled vector is written to result (allocated here when result is NULL),
 * *value is pointed at it, and true is returned. Otherwise *value is left
 * alone and false is returned, meaning the value should not be indexed.
 *
 * The caller needs to free the pointer stored in value
 * if it's different than the original value
 */
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
{
	double		norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
	Vector	   *source;
	int			i;

	/* Anything that is not strictly positive cannot be normalized */
	if (!(norm > 0))
		return false;

	source = DatumGetVector(*value);

	if (result == NULL)
		result = InitVector(source->dim);

	i = 0;
	while (i < source->dim)
	{
		result->x[i] = source->x[i] / norm;
		i++;
	}

	*value = PointerGetDatum(result);

	return true;
}
/*
 * Extend the relation by one page and return it exclusively locked
 */
Buffer
HnswNewBuffer(Relation index, ForkNumber forkNum)
{
	Buffer		fresh;

	fresh = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
	LockBuffer(fresh, BUFFER_LOCK_EXCLUSIVE);

	return fresh;
}
/*
 * Format a page for HNSW use
 *
 * Reserves special space for HnswPageOpaqueData, tags the page with
 * HNSW_PAGE_ID, and marks it as having no successor.
 */
void
HnswInitPage(Buffer buf, Page page)
{
	PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData));

	HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID;
	HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber;
}
/*
 * Start a generic WAL record, register the buffer as a full image, and
 * initialize the registered page
 *
 * The new WAL state and the registered page are handed back through state
 * and page.
 */
void
HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
	GenericXLogState *xlog = GenericXLogStart(index);

	*state = xlog;
	*page = GenericXLogRegisterBuffer(xlog, *buf, GENERIC_XLOG_FULL_IMAGE);

	HnswInitPage(*buf, *page);
}
/*
 * Finish a page modification
 *
 * Marks the buffer dirty, completes the generic WAL record, then unlocks
 * and unpins the buffer.
 */
void
HnswCommitBuffer(Buffer buf, GenericXLogState *state)
{
	/* Make the change durable */
	MarkBufferDirty(buf);
	GenericXLogFinish(state);

	/* Let go of the page */
	UnlockReleaseBuffer(buf);
}
/*
 * Allocate an element stub that only knows its index location
 *
 * neighbors and vec start out NULL; the remaining fields are left
 * uninitialized here.
 */
static HnswElement
CreateElementFromBlock(BlockNumber blkno, OffsetNumber offno)
{
	HnswElement stub = palloc(sizeof(HnswElementData));

	/* Nothing loaded yet */
	stub->vec = NULL;
	stub->neighbors = NULL;

	/* Location of the element tuple */
	stub->blkno = blkno;
	stub->offno = offno;

	return stub;
}
/*
 * Read the entry point from the metapage
 *
 * Returns NULL when the metapage holds no valid entry block, otherwise a
 * newly allocated element stub for the recorded block and offset.
 */
HnswElement
GetEntryPoint(Relation index)
{
	HnswElement result = NULL;
	Buffer		metabuf;
	Page		metapage;
	HnswMetaPage meta;

	/* The metapage is only read, so a share lock is enough */
	metabuf = ReadBuffer(index, HNSW_METAPAGE_BLKNO);
	LockBuffer(metabuf, BUFFER_LOCK_SHARE);
	metapage = BufferGetPage(metabuf);
	meta = HnswPageGetMeta(metapage);

	if (BlockNumberIsValid(meta->entryBlkno))
		result = CreateElementFromBlock(meta->entryBlkno, meta->entryOffno);

	UnlockReleaseBuffer(metabuf);

	return result;
}
/*
* Compare candidate distances
*/
static int
CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
return 1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
return -1;
return 0;
}
/*
* Compare candidate distances
*/
static int
CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance)
return -1;
if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance)
return 1;
return 0;
}
/*
 * Wrap a candidate in a newly allocated pairing heap node
 *
 * The node only references the candidate; it does not copy it.
 */
static HnswPairingHeapNode *
CreatePairingHeapNode(HnswCandidate * c)
{
	HnswPairingHeapNode *wrapper;

	wrapper = palloc(sizeof(HnswPairingHeapNode));
	wrapper->inner = c;

	return wrapper;
}
/*
 * Distance between two elements at layer lc
 *
 * If either element already lists the other as a neighbor at this layer,
 * the stored distance is returned. Otherwise the distance function is
 * called on the two vectors.
 */
static float
HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation)
{
	int			i;

	/* Does a already know its distance to b? */
	if (a->neighbors != NULL)
	{
		i = 0;
		while (i < a->neighbors[lc].length)
		{
			if (a->neighbors[lc].items[i].element == b)
				return a->neighbors[lc].items[i].distance;
			i++;
		}
	}

	/* Does b already know its distance to a? */
	if (b->neighbors != NULL)
	{
		i = 0;
		while (i < b->neighbors[lc].length)
		{
			if (b->neighbors[lc].items[i].element == a)
				return b->neighbors[lc].items[i].distance;
			i++;
		}
	}

	/* Nothing cached, compute it */
	return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec)));
}
/*
 * Check if an element is closer to q than to any element from R
 *
 * Returns false as soon as some member of r is at most e->distance away
 * from e; returns true when no member is (including when r is empty).
 */
static bool
CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation)
{
	ListCell   *cell;

	foreach(cell, r)
	{
		HnswCandidate *chosen = lfirst(cell);
		float		toChosen = HnswGetDistance(e->element, chosen->element, lc, procinfo, collation);

		if (toChosen <= e->distance)
			return false;
	}

	return true;
}
/*
 * Algorithm 4 from paper
 *
 * Picks up to m neighbors from c. When c holds fewer than m candidates a
 * copy of it is returned as is and *pruned is not written. Otherwise
 * candidates are taken from the end of the list; those that pass
 * CheckElementCloser are selected, the rest are set aside and used to top
 * the selection up to m. When pruned is not NULL it receives a candidate
 * that was not selected.
 */
static List *
SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned)
{
	List	   *selected = NIL;
	List	   *remaining = list_copy(c);
	pairingheap *discarded;

	if (list_length(remaining) < m)
		return remaining;

	discarded = pairingheap_allocate(CompareNearestCandidates, NULL);

	while (list_length(selected) < m && list_length(remaining) > 0)
	{
		/* Assumes the list is already ordered desc */
		HnswCandidate *nearest = llast(remaining);

		remaining = list_delete_last(remaining);

		if (CheckElementCloser(nearest, selected, lc, procinfo, collation))
			selected = lappend(selected, nearest);
		else
			pairingheap_add(discarded, &(CreatePairingHeapNode(nearest)->ph_node));
	}

	/* Keep pruned connections */
	while (list_length(selected) < m && !pairingheap_is_empty(discarded))
	{
		HnswPairingHeapNode *node = (HnswPairingHeapNode *) pairingheap_remove_first(discarded);

		selected = lappend(selected, node->inner);
	}

	/* Return pruned for update connections */
	if (pruned != NULL)
	{
		if (pairingheap_is_empty(discarded))
			*pruned = linitial(remaining);
		else
			*pruned = ((HnswPairingHeapNode *) pairingheap_first(discarded))->inner;
	}

	return selected;
}
/*
 * Copy the chosen neighbors into the element's neighbor array for layer lc
 *
 * Entries are appended after whatever the array already holds; no capacity
 * check is made here.
 */
static void
AddConnections(HnswElement element, List *neighbors, int m, int lc)
{
	HnswNeighborArray *dest = &element->neighbors[lc];
	ListCell   *cell;

	foreach(cell, neighbors)
	{
		HnswCandidate *src = (HnswCandidate *) lfirst(cell);

		dest->items[dest->length] = *src;
		dest->length++;
	}
}
/*
 * Record a pending neighbor-page update
 *
 * The candidate is copied by value, together with the layer and the
 * position within that layer's neighbor array.
 */
static HnswUpdate *
CreateUpdate(HnswCandidate * hc, int level, int index)
{
	HnswUpdate *pending = palloc(sizeof(HnswUpdate));

	pending->index = index;
	pending->level = level;
	pending->hc = *hc;

	return pending;
}
/*
 * List sort comparator on candidate distance
 *
 * Returns 1 when a has the smaller distance and -1 when it has the larger
 * one, so larger distances sort first.
 */
static int
#if PG_VERSION_NUM >= 130000
CompareCandidateDistances(const ListCell *a, const ListCell *b)
#else
CompareCandidateDistances(const void *a, const void *b)
#endif
{
	HnswCandidate *left = lfirst((ListCell *) a);
	HnswCandidate *right = lfirst((ListCell *) b);

	if (left->distance > right->distance)
		return -1;

	if (left->distance < right->distance)
		return 1;

	return 0;
}
/*
 * Append a copy of a heap TID to an element's list
 *
 * The TID is copied into freshly allocated memory, so the caller's pointer
 * does not need to stay valid.
 */
void
HnswAddHeapTid(HnswElement element, ItemPointer heaptid)
{
	ItemPointer owned;

	owned = palloc(sizeof(ItemPointerData));
	ItemPointerCopy(heaptid, owned);

	element->heaptids = lappend(element->heaptids, owned);
}
/*
 * Load an element from its index tuple and optionally get its distance from q
 *
 * Under a share lock on the element's page this rebuilds the heap tid list
 * and copies level, neighborPage, and deleted from the tuple. The vector is
 * copied only when loadvec is set. When distance is not NULL, the distance
 * between *q and the on-page vector is stored there.
 */
void
HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec)
{
	Buffer		buf;
	Page		page;
	HnswElementTuple etup;
	int			i;

	/* Pin and share-lock the page holding the element tuple */
	buf = ReadBuffer(index, element->blkno);
	LockBuffer(buf, BUFFER_LOCK_SHARE);
	page = BufferGetPage(buf);

	etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno));

	/* Heap tids are packed at the front; stop at the first invalid one */
	element->heaptids = NIL;
	for (i = 0; i < HNSW_HEAPTIDS && ItemPointerIsValid(&etup->heaptids[i]); i++)
		HnswAddHeapTid(element, &etup->heaptids[i]);

	/* Scalar fields */
	element->level = etup->level;
	element->neighborPage = etup->neighborPage;
	element->deleted = etup->deleted;

	/* Private copy of the vector, only when asked for */
	if (loadvec)
	{
		element->vec = palloc(VECTOR_SIZE(etup->vec.dim));
		memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim));
	}

	/* Distance is computed against the on-page vector while still locked */
	if (distance != NULL)
		*distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec)));

	UnlockReleaseBuffer(buf);
}
/*
 * Update connections
 *
 * For every selected neighbor of element at layer lc, add a link from that
 * neighbor back to element. If the neighbor still has fewer than m
 * connections the link is appended; otherwise one existing connection is
 * chosen for pruning and its slot is reused. When updates is not NULL,
 * each modified slot is recorded there via CreateUpdate.
 *
 * index is NULL for callers that have all vectors in memory; when it is
 * not NULL, missing neighbor vectors are loaded from the index first.
 */
static void
UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation)
{
ListCell *lc2;
foreach(lc2, neighbors)
{
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc];
/* The back-link: same distance as the forward link, pointing at element */
HnswCandidate hc2;
hc2.element = element;
hc2.distance = hc->distance;
if (currentNeighbors->length < m)
{
/* Room left, so append */
currentNeighbors->items[currentNeighbors->length++] = hc2;
/* Track updates */
if (updates != NULL)
*updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1));
}
else
{
/* Shrink connections */
HnswCandidate *pruned = NULL;
List *c = NIL;
/* Add and sort candidates */
/* The list holds pointers into the neighbor array plus the stack-allocated back-link */
for (int i = 0; i < currentNeighbors->length; i++)
c = lappend(c, &currentNeighbors->items[i]);
c = lappend(c, &hc2);
list_sort(c, CompareCandidateDistances);
/* Load elements on insert */
if (index != NULL)
{
for (int i = 0; i < currentNeighbors->length; i++)
{
if (currentNeighbors->items[i].element->vec == NULL)
{
HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true);
/* Prune deleted element */
/* The first deleted neighbor found is replaced directly, skipping selection */
if (currentNeighbors->items[i].element->deleted)
{
pruned = &currentNeighbors->items[i];
break;
}
}
}
}
if (pruned == NULL)
{
/* Only the pruned candidate is of interest here; the returned list is discarded */
SelectNeighbors(c, m, lc, procinfo, collation, &pruned);
/* Should not happen */
if (pruned == NULL)
continue;
}
/* Find and replace the pruned element */
/* If no existing connection matches the pruned candidate, nothing is changed */
for (int i = 0; i < currentNeighbors->length; i++)
{
if (currentNeighbors->items[i].element == pruned->element)
{
currentNeighbors->items[i] = hc2;
/* Track updates */
if (updates != NULL)
*updates = lappend(*updates, CreateUpdate(hc, lc, i));
break;
}
}
}
}
}
/*
 * Allocate empty neighbor arrays for every layer of an element
 *
 * One array is created for each layer from 0 up to element->level, sized
 * by GetLayerM for that layer.
 */
void
HnswInitNeighbors(HnswElement element, int m)
{
	int			topLevel = element->level;
	int			layer = 0;

	element->neighbors = palloc(sizeof(HnswNeighborArray) * (topLevel + 1));

	while (layer <= topLevel)
	{
		HnswNeighborArray *slots = &element->neighbors[layer];
		int			capacity = GetLayerM(m, layer);

		slots->length = 0;
		slots->items = palloc(sizeof(HnswCandidate) * capacity);

		layer++;
	}
}
/*
 * Read a candidate's neighbors from its neighbor page
 *
 * The element's neighbor arrays are always (re)initialized. They are only
 * filled when the page holds exactly (level + 2) * m items; any other
 * count is taken to mean the page now belongs to a different element.
 */
static void
LoadNeighbors(HnswCandidate * c, Relation index)
{
	int			m = HnswGetM(index);
	Buffer		buf;
	Page		page;
	OffsetNumber offno;
	OffsetNumber maxoffno;

	buf = ReadBuffer(index, c->element->neighborPage);
	LockBuffer(buf, BUFFER_LOCK_SHARE);
	page = BufferGetPage(buf);
	maxoffno = PageGetMaxOffsetNumber(page);

	HnswInitNeighbors(c->element, m);

	/*
	 * A mismatching item count means the neighbor page represents a new
	 * item. This is only caught if that item has a different level.
	 */
	/* TODO Use versioning to fix this? */
	if (maxoffno != (c->element->level + 2) * m)
	{
		UnlockReleaseBuffer(buf);
		return;
	}

	for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
	{
		HnswNeighborTuple neighbor = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno));
		HnswNeighborArray *neighbors;
		HnswElement element;
		HnswCandidate *hc;
		int			level;

		/* Unused slot */
		if (!ItemPointerIsValid(&neighbor->indextid))
			continue;

		element = CreateElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid));

		/* Slots run from the top level down; the tail belongs to level 0 */
		level = c->element->level - (offno - FirstOffsetNumber) / m;
		if (level < 0)
			level = 0;

		/* Append to that level's array */
		neighbors = &c->element->neighbors[level];
		hc = &neighbors->items[neighbors->length];
		hc->element = element;
		hc->distance = neighbor->distance;
		neighbors->length++;
	}

	UnlockReleaseBuffer(buf);
}
/*
 * Allocate an element
 *
 * The element gets a randomly drawn level of -log(RandomDouble()) * ml,
 * truncated to an integer and capped at maxLevel, a heap tid list holding
 * a copy of heaptid, and empty neighbor arrays for every layer up to its
 * level. The vector is not set here.
 */
HnswElement
HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel)
{
	HnswElement element = palloc(sizeof(HnswElementData));
	double		rawLevel = -log(RandomDouble()) * ml;
	int			level;

	/*
	 * Cap level before converting to int. If RandomDouble() ever yields 0,
	 * log() returns -infinity and rawLevel becomes +infinity; converting an
	 * out-of-range double to int is undefined behavior in C, so the cap has
	 * to be applied in the floating-point domain. The negated comparison
	 * also routes NaN to the cap.
	 */
	if (!(rawLevel < maxLevel))
		level = maxLevel;
	else
		level = (int) rawLevel;

	element->heaptids = NIL;
	HnswAddHeapTid(element, heaptid);

	element->level = level;
	element->deleted = 0;

	HnswInitNeighbors(element, m);

	return element;
}
/*
 * Release an element and the memory it references
 *
 * Frees the heap tid list, each layer's neighbor items, the neighbor array
 * itself, the vector, and finally the element.
 */
void
HnswFreeElement(HnswElement element)
{
	int			layer = 0;

	list_free(element->heaptids);

	while (layer <= element->level)
	{
		pfree(element->neighbors[layer].items);
		layer++;
	}

	pfree(element->neighbors);
	pfree(element->vec);
	pfree(element);
}
/*
 * Distance between q and a candidate's in-memory vector
 */
static float
GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation)
{
	Datum		candidateVec = PointerGetDatum(hc->element->vec);

	return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, candidateVec));
}
/*
 * Check if a candidate's element is already in a list, comparing element
 * pointers
 */
static bool
HnswListMemberByPointer(List *v, HnswCandidate * e)
{
	ListCell   *cell;

	foreach(cell, v)
	{
		HnswCandidate *seen = (HnswCandidate *) lfirst(cell);

		if (seen->element == e->element)
			return true;
	}

	return false;
}
/*
 * Check if a candidate's element is already in a list, comparing index
 * block and offset
 */
static bool
HnswListMemberByBlock(List *v, HnswCandidate * e)
{
	BlockNumber wantedBlkno = e->element->blkno;
	OffsetNumber wantedOffno = e->element->offno;
	ListCell   *cell;

	foreach(cell, v)
	{
		HnswCandidate *seen = (HnswCandidate *) lfirst(cell);

		/* Neighbor page not set yet, so identify by tuple location */
		if (seen->element->blkno == wantedBlkno && seen->element->offno == wantedOffno)
			return true;
	}

	return false;
}
/*
 * Algorithm 2 from paper
 *
 * Searches layer lc for elements near q, starting from the candidates in
 * ep, and returns a list of at most ef candidates (more only if ep itself
 * holds more than ef entries).
 *
 * index is NULL when every element is already in memory; otherwise
 * neighbors and elements are loaded from the index as they are reached.
 * inserting is passed to HnswLoadElement as its loadvec flag and also
 * causes elements without heap tids to be skipped. When skipPage is not
 * NULL, elements whose neighbor page equals *skipPage are skipped.
 */
List *
SearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage)
{
ListCell *lc2;
List *w = NIL;
/* v is the visited list */
List *v = NIL;
/* C holds candidates still to expand, W the current result set */
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
/* Number of entries in W, tracked by hand */
int wlen = 0;
/* Add entry points to v, C, and W */
foreach(lc2, ep)
{
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
v = lappend(v, hc);
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
wlen++;
}
while (!pairingheap_is_empty(C))
{
HnswNeighborArray *neighborhood;
/* c is taken off C; f is the top of W and is only peeked at */
HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner;
HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
/* Stop once the candidate taken from C is further away than the top of W */
if (c->distance > f->distance)
break;
/* Neighbors are loaded from the index the first time an element is expanded */
if (c->element->neighbors == NULL)
LoadNeighbors(c, index);
/* Get the neighborhood at layer lc */
neighborhood = &c->element->neighbors[lc];
for (int i = 0; i < neighborhood->length; i++)
{
HnswCandidate *e = &neighborhood->items[i];
bool visited;
/* TODO Use hash for v? */
/* In-memory elements are identified by pointer, on-disk ones by block and offset */
visited = index == NULL ? HnswListMemberByPointer(v, e) : HnswListMemberByBlock(v, e);
if (!visited)
{
float eDistance;
v = lappend(v, e);
/* Refresh f, since W may have changed in an earlier iteration */
f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
if (index == NULL)
eDistance = GetCandidateDistance(e, q, procinfo, collation);
else
HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting);
/* Skip if fully deleted */
if (e->element->deleted)
continue;
/* Skip for inserts if deleting */
if (inserting && list_length(e->element->heaptids) == 0)
continue;
/* Skip self for vacuuming update */
if (skipPage != NULL && e->element->neighborPage == *skipPage)
continue;
/* Stale read */
if (e->element->level < lc)
continue;
if (eDistance < f->distance || wlen < ef)
{
/* copy e */
/* The copy carries the distance to q rather than the stored neighbor distance */
HnswCandidate *e2 = palloc(sizeof(HnswCandidate));
e2->element = e->element;
e2->distance = eDistance;
pairingheap_add(C, &(CreatePairingHeapNode(e2)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(e2)->ph_node));
wlen++;
/* Keep W at ef entries by dropping its top entry */
if (wlen > ef)
{
pairingheap_remove_first(W);
wlen--;
}
}
}
}
}
/* Add each element of W to w */
/* Entries are appended in the order W yields them */
while (!pairingheap_is_empty(W))
{
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
w = lappend(w, hc);
}
return w;
}
/*
 * Build the starting candidate for a search from the entry point
 *
 * With an index, the entry point is loaded from disk (including its vector
 * when loadvec is set) and the distance is computed during the load.
 * Without one, the distance is computed from the in-memory vector.
 */
HnswCandidate *
EntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec)
{
	HnswCandidate *start = palloc(sizeof(HnswCandidate));

	start->element = entryPoint;

	if (index != NULL)
		HnswLoadElement(start->element, &start->distance, &q, index, procinfo, collation, loadvec);
	else
		start->distance = GetCandidateDistance(start, q, procinfo, collation);

	return start;
}
/*
 * Look for a neighbor with an identical vector that can take another heap tid
 *
 * Scans neighbors from the front and stops at the first one whose vector
 * differs from e's. Returns the first identical neighbor holding fewer
 * than HNSW_HEAPTIDS heap tids, or NULL when there is none.
 */
static HnswElement
HnswFindDuplicate(HnswElement e, List *neighbors)
{
	ListCell   *cell;

	foreach(cell, neighbors)
	{
		HnswCandidate *candidate = lfirst(cell);
		bool		sameVector = vector_cmp_internal(e->vec, candidate->element->vec) == 0;

		/* Exit early since ordered by distance */
		if (!sameVector)
			break;

		/* Only usable when a heap tid slot is still free */
		if (list_length(candidate->element->heaptids) < HNSW_HEAPTIDS)
			return candidate->element;
	}

	return NULL;
}
/*
 * Algorithm 1 from paper
 *
 * Finds neighbors for element on every layer it shares with the graph and
 * links it in. Returns an existing element when, outside of vacuuming, a
 * duplicate with a free heap tid slot is found on layer 0; in that case no
 * connections are changed. Returns NULL otherwise.
 *
 * While vacuuming, only the element's own connections are filled in and
 * elements on the element's own neighbor page are skipped during search.
 */
HnswElement
HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming)
{
	List	   *ep = NIL;
	int			level = element->level;
	int			entryLevel = -1;
	List	  **newNeighbors = palloc(sizeof(List *) * (level + 1));
	Datum		q = PointerGetDatum(element->vec);
	BlockNumber *skipPage = vacuuming ? &element->neighborPage : NULL;
	int			layer;

	/* Get entry point and level; -1 stands for an empty graph */
	if (entryPoint != NULL)
	{
		ep = lappend(ep, EntryCandidate(entryPoint, q, index, procinfo, collation, true));
		entryLevel = entryPoint->level;
	}

	/* Layers above the element's level: keep a single best candidate */
	for (layer = entryLevel; layer > level; layer--)
		ep = SearchLayer(q, ep, 1, layer, index, procinfo, collation, true, skipPage);

	/* Layers above the entry point have nothing to connect to */
	if (level > entryLevel)
		level = entryLevel;

	/* Shared layers: search wide and pick neighbors */
	for (layer = level; layer >= 0; layer--)
	{
		int			layerM = GetLayerM(m, layer);

		ep = SearchLayer(q, ep, efConstruction, layer, index, procinfo, collation, true, skipPage);
		newNeighbors[layer] = SelectNeighbors(ep, layerM, layer, procinfo, collation, NULL);
	}

	/* Hand back a duplicate instead of linking a new element */
	if (level >= 0 && !vacuuming)
	{
		HnswElement dup = HnswFindDuplicate(element, newNeighbors[0]);

		if (dup != NULL)
			return dup;
	}

	/* Update connections */
	for (layer = level; layer >= 0; layer--)
	{
		int			layerM = GetLayerM(m, layer);

		AddConnections(element, newNeighbors[layer], layerM, layer);

		if (!vacuuming)
			UpdateConnections(element, newNeighbors[layer], layerM, layer, updates, index, procinfo, collation);
	}

	return NULL;
}
/*
 * Update the metapage
 *
 * When updateEntry is set, the entry point location is replaced with that
 * of entryPoint. When insertPage is a valid block number, the recorded
 * insert page is replaced too. The page is written and released either
 * way.
 */
void
UpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum)
{
	Buffer		metabuf;
	Page		metapage;
	GenericXLogState *xlog;
	HnswMetaPage meta;

	/* Lock and register the metapage */
	metabuf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL);
	LockBuffer(metabuf, BUFFER_LOCK_EXCLUSIVE);
	xlog = GenericXLogStart(index);
	metapage = GenericXLogRegisterBuffer(xlog, metabuf, 0);
	meta = HnswPageGetMeta(metapage);

	if (BlockNumberIsValid(insertPage))
		meta->insertPage = insertPage;

	if (updateEntry)
	{
		meta->entryBlkno = entryPoint->blkno;
		meta->entryOffno = entryPoint->offno;
	}

	HnswCommitBuffer(metabuf, xlog);
}
/*
 * Write an element's neighbor slots onto a page
 *
 * Slots are emitted from the element's top level down to level 0, with
 * GetLayerM slots per level. Slots beyond the number of known neighbors
 * are written with an invalid tid and a NaN distance, so every level
 * always occupies its full slot count.
 */
void
AddNeighborsToPage(Relation index, Page page, HnswElement e, HnswNeighborTuple neighbor, Size neighborsz, int m)
{
	int			lc;

	for (lc = e->level; lc >= 0; lc--)
	{
		HnswNeighborArray *known = &e->neighbors[lc];
		int			slots = GetLayerM(m, lc);
		int			i = 0;

		while (i < slots)
		{
			if (i >= known->length)
			{
				/* Placeholder for a connection that does not exist yet */
				ItemPointerSetInvalid(&neighbor->indextid);
				neighbor->distance = NAN;
			}
			else
			{
				HnswCandidate *hc = &known->items[i];

				ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno);
				neighbor->distance = hc->distance;
			}

			if (PageAddItem(page, (Item) neighbor, neighborsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
				elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

			i++;
		}
	}
}

517
src/hnswvacuum.c Normal file
View File

@@ -0,0 +1,517 @@
#include "postgres.h"
#include "commands/vacuum.h"
#include "hnsw.h"
#include "storage/bufmgr.h"
#include "utils/memutils.h"
/*
 * Check whether an index tid is present in the deleted hash table
 */
static bool
DeletedContains(HTAB *deleted, ItemPointer indextid)
{
	bool		present;

	/* Lookup only; the table is not modified */
	(void) hash_search(deleted, indextid, HASH_FIND, &present);

	return present;
}
/*
 * Remove deleted heap tids
 *
 * Walks the chain of pages starting at HNSW_HEAD_BLKNO. For each element
 * tuple, heap tids reported dead by the vacuum callback are dropped and the
 * surviving ones are packed to the front. Elements left without any valid
 * heap tid are added to vacuumstate->deleted. Among the remaining
 * elements, the one with the highest level that is not the current entry
 * point is remembered in vacuumstate->highestPoint.
 *
 * OK to remove for entry point, since always considered for searches and inserts
 */
static void
RemoveHeapTids(HnswVacuumState * vacuumstate)
{
	BlockNumber blkno = HNSW_HEAD_BLKNO;
	HnswElement highestPoint = &vacuumstate->highestPoint;
	Relation	index = vacuumstate->index;
	BufferAccessStrategy bas = vacuumstate->bas;
	HnswElement entryPoint = GetEntryPoint(vacuumstate->index);

	/* Store separately since highestPoint.level is uint8 */
	int			highestLevel = -1;

	/* Initialize highest point */
	highestPoint->blkno = InvalidBlockNumber;
	highestPoint->offno = InvalidOffsetNumber;

	while (BlockNumberIsValid(blkno))
	{
		Buffer		buf;
		Page		page;
		GenericXLogState *state;
		OffsetNumber offno;
		OffsetNumber maxoffno;
		bool		updated = false;

		vacuum_delay_point();

		buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas);
		LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
		state = GenericXLogStart(index);
		page = GenericXLogRegisterBuffer(state, buf, 0);
		maxoffno = PageGetMaxOffsetNumber(page);

		/* Iterate over nodes */
		for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
		{
			HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
			int			idx = 0;
			bool		itemUpdated = false;

			if (ItemPointerIsValid(&item->heaptids[0]))
			{
				for (int i = 0; i < HNSW_HEAPTIDS; i++)
				{
					/* Stop at first unused */
					if (!ItemPointerIsValid(&item->heaptids[i]))
						break;

					if (vacuumstate->callback(&item->heaptids[i], vacuumstate->callback_state))
						itemUpdated = true;
					else
					{
						/* Move to front of list */
						item->heaptids[idx++] = item->heaptids[i];
					}
				}

				if (itemUpdated)
				{
					Size		itemsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(item->vec.dim));

					/* Mark rest as invalid */
					for (int i = idx; i < HNSW_HEAPTIDS; i++)
						ItemPointerSetInvalid(&item->heaptids[i]);

					if (!PageIndexTupleOverwrite(page, offno, (Item) item, itemsz))
						elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

					updated = true;
				}
			}

			if (!ItemPointerIsValid(&item->heaptids[0]))
			{
				ItemPointerData ip;

				/* Add to deleted list */
				ItemPointerSet(&ip, blkno, offno);
				(void) hash_search(vacuumstate->deleted, &ip, HASH_ENTER, NULL);
			}
			else if (item->level > highestLevel &&
					 !(entryPoint != NULL && blkno == entryPoint->blkno && offno == entryPoint->offno))
			{
				/*
				 * Keep track of highest non-entry point. The element being
				 * examined (blkno, offno) is what must differ from the entry
				 * point; comparing the previously recorded highest point
				 * instead would let the entry point itself be recorded and
				 * then stop any further updates. GetEntryPoint can return
				 * NULL, so guard against that as well.
				 */
				/* TODO Keep track of closest one to entry point? */
				highestPoint->blkno = blkno;
				highestPoint->offno = offno;
				highestLevel = item->level;
			}
		}

		/* Read the link before the buffer is released */
		blkno = HnswPageGetOpaque(page)->nextblkno;

		if (updated)
		{
			MarkBufferDirty(buf);
			GenericXLogFinish(state);
		}
		else
			GenericXLogAbort(state);

		UnlockReleaseBuffer(buf);
	}
}
/*
 * Check whether any neighbor of an element is in the deleted set
 *
 * Reads the element's neighbor page under a share lock and returns true as
 * soon as one valid neighbor tid is found in vacuumstate->deleted.
 */
static bool
NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element)
{
	Relation	index = vacuumstate->index;
	BufferAccessStrategy bas = vacuumstate->bas;
	bool		needsUpdated = false;
	Buffer		buf;
	Page		page;
	OffsetNumber offno;
	OffsetNumber maxoffno;

	buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas);
	LockBuffer(buf, BUFFER_LOCK_SHARE);
	page = BufferGetPage(buf);
	maxoffno = PageGetMaxOffsetNumber(page);

	/* Scan neighbor slots until a deleted one turns up */
	for (offno = FirstOffsetNumber; !needsUpdated && offno <= maxoffno; offno = OffsetNumberNext(offno))
	{
		HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno));

		/* Empty slots are ignored; valid ones are looked up in the deleted set */
		if (ItemPointerIsValid(&ntup->indextid) && DeletedContains(vacuumstate->deleted, &ntup->indextid))
			needsUpdated = true;
	}

	UnlockReleaseBuffer(buf);

	return needsUpdated;
}
/*
 * Repair graph for a single element
 *
 * Returns without changes unless NeedsUpdated() reports that one of the
 * element's stored neighbors is in the deleted list. Otherwise the
 * neighbors are rebuilt with HnswInsertElement() and the element's
 * neighbor page is reinitialized and rewritten.
 */
static void
RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element)
{
	Relation	index = vacuumstate->index;
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	int			m = vacuumstate->m;
	int			efConstruction = vacuumstate->efConstruction;
	FmgrInfo   *procinfo = vacuumstate->procinfo;
	Oid			collation = vacuumstate->collation;
	HnswElement entryPoint;
	BufferAccessStrategy bas = vacuumstate->bas;
	HnswNeighborTuple ntup = vacuumstate->ntup;
	Size		nsize = vacuumstate->nsize;

	/* Check if any neighbors point to deleted values */
	if (!NeedsUpdated(vacuumstate, element))
		return;

	/* Refresh entry point for each element */
	entryPoint = GetEntryPoint(index);

	/*
	 * Special case for entry point
	 *
	 * GetEntryPoint can return NULL (RepairGraphEntryPoint checks for it),
	 * so only compare when there is an entry point. With no entry point the
	 * element is passed to HnswInsertElement with a NULL entry point, the
	 * same as the branch below that sets it to NULL.
	 */
	if (entryPoint != NULL && element->blkno == entryPoint->blkno && element->offno == entryPoint->offno)
	{
		if (BlockNumberIsValid(vacuumstate->highestPoint.blkno))
		{
			/* Already updated */
			if (vacuumstate->highestPoint.blkno == element->blkno && vacuumstate->highestPoint.offno == element->offno)
				return;

			/* Search from the highest non-entry point instead of itself */
			entryPoint = &vacuumstate->highestPoint;
		}
		else
			entryPoint = NULL;
	}

	/* Start from an empty neighbor set and recompute it */
	HnswInitNeighbors(element, m);
	element->heaptids = NIL;

	HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, NULL, true);

	/* Write out new neighbors on page */
	buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas);
	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
	state = GenericXLogStart(index);
	page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE);

	/* Overwrites existing page via InitPage */
	HnswInitPage(buf, page);

	/* Update neighbors */
	AddNeighborsToPage(index, page, element, ntup, nsize, m);

	/* Commit */
	MarkBufferDirty(buf);
	GenericXLogFinish(state);
	UnlockReleaseBuffer(buf);
}
/*
 * Repair the graph around the entry point
 *
 * Runs in the vacuum state's temporary memory context, which is reset
 * before returning.
 */
static void
RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
{
	Relation	rel = vacuumstate->index;
	HnswElement candidate = &vacuumstate->highestPoint;
	HnswElement current;
	MemoryContext callerCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx);

	/*
	 * Repair the highest non-entry point first. New inserts may have added a
	 * higher one since it was recorded, but that should be fine.
	 */
	if (BlockNumberIsValid(candidate->blkno))
	{
		HnswLoadElement(candidate, NULL, NULL, rel, vacuumstate->procinfo, vacuumstate->collation, true);
		RepairGraphElement(vacuumstate, candidate);
	}

	current = GetEntryPoint(rel);
	if (current != NULL)
	{
		ItemPointerData currentTid;

		ItemPointerSet(&currentTid, current->blkno, current->offno);

		if (!DeletedContains(vacuumstate->deleted, &currentTid))
		{
			/* Entry point survives; the highest point is used to repair it */
			HnswLoadElement(current, NULL, NULL, rel, vacuumstate->procinfo, vacuumstate->collation, true);
			RepairGraphElement(vacuumstate, current);
		}
		else
		{
			/* Entry point is being deleted; hand over to the highest point */
			UpdateMetaPage(rel, true, candidate, InvalidBlockNumber, MAIN_FORKNUM);
		}
	}

	/* Restore the caller's context and release temporary allocations */
	MemoryContextSwitchTo(callerCtx);
	MemoryContextReset(vacuumstate->tmpCtx);
}
/*
 * Repair graph for all elements
 *
 * The entry point is handled first, then every element page is visited
 * in chain order starting at HNSW_HEAD_BLKNO.
 */
static void
RepairGraph(HnswVacuumState * vacuumstate)
{
	Relation	rel = vacuumstate->index;
	BufferAccessStrategy strategy = vacuumstate->bas;
	BlockNumber nextblk;

	RepairGraphEntryPoint(vacuumstate);

	for (nextblk = HNSW_HEAD_BLKNO; BlockNumberIsValid(nextblk);)
	{
		Buffer		ebuf;
		Page		epage;
		OffsetNumber cur;
		OffsetNumber last;
		List	   *pending = NIL;
		ListCell   *cell;
		MemoryContext callerCtx;

		vacuum_delay_point();

		/* Per-page allocations live in the temporary context */
		callerCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx);

		ebuf = ReadBufferExtended(rel, MAIN_FORKNUM, nextblk, RBM_NORMAL, strategy);
		LockBuffer(ebuf, BUFFER_LOCK_SHARE);
		epage = BufferGetPage(ebuf);
		last = PageGetMaxOffsetNumber(epage);

		/* Copy live elements out of the page to minimize locking */
		for (cur = FirstOffsetNumber; cur <= last; cur = OffsetNumberNext(cur))
		{
			HnswElementTuple etup = (HnswElementTuple) PageGetItem(epage, PageGetItemId(epage, cur));

			/* Elements without heap tids are being deleted; leave them alone */
			if (ItemPointerIsValid(&etup->heaptids[0]))
			{
				Size		vecsize = VECTOR_SIZE(etup->vec.dim);
				HnswElement copy = palloc(sizeof(HnswElementData));

				copy->blkno = nextblk;
				copy->offno = cur;
				copy->level = etup->level;
				copy->neighborPage = etup->neighborPage;
				copy->vec = palloc(vecsize);
				memcpy(copy->vec, &etup->vec, vecsize);

				pending = lappend(pending, copy);
			}
		}

		/* Remember the next page before dropping the lock */
		nextblk = HnswPageGetOpaque(epage)->nextblkno;

		UnlockReleaseBuffer(ebuf);

		/* Update neighbor pages now that the element page is unlocked */
		foreach(cell, pending)
			RepairGraphElement(vacuumstate, (HnswElement) lfirst(cell));

		/* Release everything allocated for this page */
		MemoryContextSwitchTo(callerCtx);
		MemoryContextReset(vacuumstate->tmpCtx);
	}
}
/*
 * Mark items as deleted
 *
 * Walks the element pages starting at HNSW_HEAD_BLKNO. Every element tuple
 * whose first heap tid is invalid gets its deleted flag set and its vector
 * zeroed, and its neighbor page is reinitialized. The element page and the
 * neighbor page are modified under the same generic WAL record. The first
 * page found to contain such an element (or InvalidBlockNumber if there is
 * none) is passed to UpdateMetaPage as the insert page.
 */
static void
MarkDeleted(HnswVacuumState * vacuumstate)
{
BlockNumber blkno = HNSW_HEAD_BLKNO;
BlockNumber insertPage = InvalidBlockNumber;
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
while (BlockNumberIsValid(blkno))
{
Buffer buf;
Page page;
GenericXLogState *state;
OffsetNumber offno;
OffsetNumber maxoffno;
vacuum_delay_point();
buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas);
/*
 * ambulkdelete cannot delete entries from pages that are pinned by
 * other backends
 *
 * https://www.postgresql.org/docs/current/index-locking.html
 */
LockBufferForCleanup(buf);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
maxoffno = PageGetMaxOffsetNumber(page);
/* Update element and neighbors together */
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno));
Size itemsz;
Buffer nbuf;
Page npage;
/* Only elements with no valid heap tid are marked */
if (ItemPointerIsValid(&item->heaptids[0]))
continue;
/* Overwrite element */
/* TODO Increment version? */
item->deleted = 1;
MemSet(&item->vec.x, 0, item->vec.dim * sizeof(float));
itemsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(item->vec.dim));
if (!PageIndexTupleOverwrite(page, offno, (Item) item, itemsz))
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Overwrite neighbors */
/* The neighbor page is locked while the element page is still held */
nbuf = ReadBufferExtended(index, MAIN_FORKNUM, item->neighborPage, RBM_NORMAL, bas);
LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE);
npage = GenericXLogRegisterBuffer(state, nbuf, GENERIC_XLOG_FULL_IMAGE);
HnswInitPage(nbuf, npage);
/* Commit */
/*
 * NOTE(review): GenericXLogFinish is documented to mark registered
 * buffers dirty itself, so these explicit calls look redundant - confirm
 */
MarkBufferDirty(buf);
MarkBufferDirty(nbuf);
GenericXLogFinish(state);
UnlockReleaseBuffer(nbuf);
/* Set to first free page */
if (!BlockNumberIsValid(insertPage))
insertPage = blkno;
/* Prepare new xlog */
/* page is re-registered, so the next iteration reads items from it */
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
}
blkno = HnswPageGetOpaque(page)->nextblkno;
/* Changes were already finished in the loop; nothing is pending in this state */
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
/* Pass the first page with a deleted element on as the insert page */
UpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM);
}
/*
 * Set up the vacuum state for a bulk delete
 *
 * Allocates a zeroed result struct when the caller did not pass one.
 */
static void
InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state)
{
	Relation	rel = info->index;
	HASHCTL		deletedCtl;

	/* Values handed in by the caller */
	vacuumstate->index = rel;
	vacuumstate->stats = (stats != NULL) ? stats : (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult));
	vacuumstate->callback = callback;
	vacuumstate->callback_state = callback_state;

	/* Index options and support function */
	vacuumstate->m = HnswGetM(rel);
	vacuumstate->efConstruction = HnswGetEfConstruction(rel);
	vacuumstate->procinfo = index_getprocinfo(rel, 1, HNSW_DISTANCE_PROC);
	vacuumstate->collation = rel->rd_indcollation[0];

	/* Working resources */
	vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD);
	vacuumstate->nsize = MAXALIGN(sizeof(HnswNeighborTupleData));
	vacuumstate->ntup = palloc0(vacuumstate->nsize);
	vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
												"Hnsw vacuum temporary context",
												ALLOCSET_DEFAULT_SIZES);

	/* Hash table of deleted index tids; the tid is both key and entry */
	deletedCtl.keysize = sizeof(ItemPointerData);
	deletedCtl.entrysize = sizeof(ItemPointerData);
	deletedCtl.hcxt = CurrentMemoryContext;
	vacuumstate->deleted = hash_create("hnswbulkdelete indextids", 256, &deletedCtl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
}
/*
 * Release what InitVacuumState set up (the stats struct is kept)
 */
static void
FreeVacuumState(HnswVacuumState * vacuumstate)
{
	/* Released in reverse order of creation */
	MemoryContextDelete(vacuumstate->tmpCtx);
	pfree(vacuumstate->ntup);
	FreeAccessStrategy(vacuumstate->bas);
	hash_destroy(vacuumstate->deleted);
}
/*
 * Bulk delete tuples from the index
 *
 * Runs three passes over the index and returns the stats struct held in
 * the vacuum state.
 */
IndexBulkDeleteResult *
hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
			   IndexBulkDeleteCallback callback, void *callback_state)
{
	HnswVacuumState state;
	IndexBulkDeleteResult *result;

	InitVacuumState(&state, info, stats, callback, callback_state);

	RemoveHeapTids(&state);		/* Pass 1: Remove heap tids */
	RepairGraph(&state);		/* Pass 2: Repair graph */
	MarkDeleted(&state);		/* Pass 3: Mark as deleted */

	/* Pick up the result before tearing the state down */
	result = state.stats;

	FreeVacuumState(&state);

	return result;
}
/*
 * Clean up after a VACUUM operation
 */
IndexBulkDeleteResult *
hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats)
{
	/*
	 * Nothing to do for analyze-only, or when stats is NULL because
	 * ambulkdelete was not called (OK to return NULL if index not changed)
	 */
	if (info->analyze_only || stats == NULL)
		return stats;

	stats->num_pages = RelationGetNumberOfBlocks(info->index);

	return stats;
}

View File

@@ -3,10 +3,6 @@
#include "postgres.h"
#if PG_VERSION_NUM < 110000
#error "Requires PostgreSQL 11+"
#endif
#include "access/generic_xlog.h"
#include "access/parallel.h"
#include "access/reloptions.h"

View File

@@ -4,6 +4,7 @@
#include "catalog/pg_type.h"
#include "fmgr.h"
#include "hnsw.h"
#include "ivfflat.h"
#include "lib/stringinfo.h"
#include "libpq/pqformat.h"
@@ -37,6 +38,7 @@ PG_MODULE_MAGIC;
void
_PG_init(void)
{
HnswInit();
IvfflatInit();
}

View File

@@ -0,0 +1,19 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
val
---------
[1,1,1]
[1,2,3]
[1,2,4]
(3 rows)
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
val
-----
(0 rows)
DROP TABLE t;

20
test/expected/hnsw_ip.out Normal file
View File

@@ -0,0 +1,20 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_ip_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
val
---------
[1,2,4]
[1,2,3]
[1,1,1]
[0,0,0]
(4 rows)
SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
val
-----
(0 rows)
DROP TABLE t;

26
test/expected/hnsw_l2.out Normal file
View File

@@ -0,0 +1,26 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,2,4]
[1,1,1]
[0,0,0]
(4 rows)
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
val
-----
(0 rows)
SELECT COUNT(*) FROM t;
count
-------
5
(1 row)
DROP TABLE t;

View File

@@ -0,0 +1,25 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3);
ERROR: value 3 out of bounds for option "m"
DETAIL: Valid values are between "4" and "100".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101);
ERROR: value 101 out of bounds for option "m"
DETAIL: Valid values are between "4" and "100".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9);
ERROR: value 9 out of bounds for option "ef_construction"
DETAIL: Valid values are between "10" and "1000".
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001);
ERROR: value 1001 out of bounds for option "ef_construction"
DETAIL: Valid values are between "10" and "1000".
SHOW hnsw.ef_search;
hnsw.ef_search
----------------
40
(1 row)
SET hnsw.ef_search = 9;
ERROR: 9 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000)
SET hnsw.ef_search = 1001;
ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000)
DROP TABLE t;

View File

@@ -0,0 +1,13 @@
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,1,1]
[0,0,0]
(3 rows)
DROP TABLE t;

12
test/sql/hnsw_cosine.sql Normal file
View File

@@ -0,0 +1,12 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_cosine_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
DROP TABLE t;

12
test/sql/hnsw_ip.sql Normal file
View File

@@ -0,0 +1,12 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_ip_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
DROP TABLE t;

13
test/sql/hnsw_l2.sql Normal file
View File

@@ -0,0 +1,13 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
SELECT COUNT(*) FROM t;
DROP TABLE t;

14
test/sql/hnsw_options.sql Normal file
View File

@@ -0,0 +1,14 @@
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9);
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001);
SHOW hnsw.ef_search;
SET hnsw.ef_search = 9;
SET hnsw.ef_search = 1001;
DROP TABLE t;

View File

@@ -0,0 +1,9 @@
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING hnsw (val vector_l2_ops);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
DROP TABLE t;

99
test/t/010_hnsw_wal.pl Normal file
View File

@@ -0,0 +1,99 @@
# Based on postgres/contrib/bloom/t/001_wal.pl
# Test generic xlog record work for hnsw index replication.
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $dim = 32;
my $node_primary;
my $node_replica;
# Wait for the replica to catch up with the primary, then run the same
# nearest-neighbor query on both nodes and check that the results match.
#
# $test_name - label used as the prefix of the test description
sub test_index_replay
{
	my ($test_name) = @_;

	# Wait for replica to catch up
	my $applname = $node_replica->name;
	my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';";
	$node_primary->poll_query_until('postgres', $caughtup_query)
	  or die "Timed out while waiting for replica 1 to catch up";

	# Random query vector with one component per dimension
	my $sql = join(",", map { rand() } 1 .. $dim);

	my $queries = qq(
SET enable_seqscan = off;
SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10;
);

	# Run test queries and compare their result
	my $primary_result = $node_primary->safe_psql("postgres", $queries);
	my $replica_result = $node_replica->safe_psql("postgres", $queries);

	is($primary_result, $replica_result, "$test_name: query result matches");
	return;
}
# Use ARRAY[random(), random(), random(), ...] over
# SELECT array_agg(random()) FROM generate_series(1, $dim)
# to generate different values for each row
my $array_sql = join(",", ('random()') x $dim);
# Initialize primary node
$node_primary = get_new_node('primary');
$node_primary->init(allows_streaming => 1);
# Extra settings that only apply when $dim is raised above its value of 32
if ($dim > 32) {
# TODO use wal_keep_segments for Postgres < 13
$node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB));
}
if ($dim > 1500) {
$node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB));
}
$node_primary->start;
my $backup_name = 'my_backup';
# Take backup
$node_primary->backup($backup_name);
# Create streaming replica linking to primary
$node_replica = get_new_node('replica');
$node_replica->init_from_backup($node_primary, $backup_name,
has_streaming => 1);
$node_replica->start;
# Create hnsw index on primary
$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;");
$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
);
$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
# Test that queries give same result
test_index_replay('initial');
# Run 10 cycles of table modification. Run test queries after each modification.
for my $i (1 .. 10)
{
# Rows have i % 10 in column i, so each cycle deletes one residue class
$node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;");
test_index_replay("delete $i");
$node_primary->safe_psql("postgres", "VACUUM tst;");
test_index_replay("vacuum $i");
# Insert the next 1000 values of the series
my ($start, $end) = (10001 + ($i - 1) * 1000, 10000 + $i * 1000);
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;"
);
test_index_replay("insert $i");
}
done_testing();

43
test/t/011_hnsw_vacuum.pl Normal file
View File

@@ -0,0 +1,43 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $dim = 3;

# One "i % N" expression per dimension, each with a random modulus in 1..1000,
# so the same rows can be regenerated after the delete
my $array_sql = join(", ", map { "i % " . (int(rand(1000)) + 1) } 1 .. $dim);

# Statements that are run more than once below
my $load_sql = "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;";
my $size_sql = "SELECT pg_total_relation_size('tst_v_idx');";

# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;

# Create table and index
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
$node->safe_psql("postgres", $load_sql);
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");

# Get size
my $size = $node->safe_psql("postgres", $size_sql);

# Delete all, vacuum, and insert same data
$node->safe_psql("postgres", "DELETE FROM tst;");
$node->safe_psql("postgres", "VACUUM tst;");
$node->safe_psql("postgres", $load_sql);

# Check size
my $new_size = $node->safe_psql("postgres", $size_sql);
is($size, $new_size, "size does not change");

done_testing();

View File

@@ -0,0 +1,96 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
# Check that index scans with the given operator find at least $min of the
# exact results stored in @expected, across all of @queries.
#
# $min      - minimum acceptable fraction of expected ids found
# $operator - distance operator to order by; also used as the test name
sub test_recall
{
	my ($min, $operator) = @_;
	my $correct = 0;
	my $total = 0;

	# Make sure the planner actually uses the index
	my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
	like($explain, qr/Index Scan/, "$operator: uses index scan");

	for my $i (0 .. $#queries) {
		my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
		my @actual_ids = split("\n", $actual);
		my %actual_set = map { $_ => 1 } @actual_ids;

		my @expected_ids = split("\n", $expected[$i]);

		for my $id (@expected_ids) {
			$correct++ if exists($actual_set{$id});
			$total++;
		}
	}

	# Report a failed check instead of dying on division by zero when there
	# were no expected rows to compare against
	my $recall = $total > 0 ? $correct / $total : 0;

	return cmp_ok($recall, ">=", $min, $operator);
}
# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;

# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
$node->safe_psql("postgres",
	"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
);

# Generate 20 random three-dimensional query vectors
for (1 .. 20) {
	push(@queries, "[" . join(",", map { rand() } 1 .. 3) . "]");
}

# Operator class to index with for each distance operator
my %opclass_for = (
	"<->" => "vector_l2_ops",
	"<#>" => "vector_ip_ops",
	"<=>" => "vector_cosine_ops",
);

# Check each index type
my @operators = ("<->", "<#>", "<=>");
for my $operator (@operators) {
	# Get exact results before the index for this operator exists
	@expected = ();
	for my $query (@queries) {
		my $exact = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$query' LIMIT $limit;");
		push(@expected, $exact);
	}

	# Add index
	my $opclass = $opclass_for{$operator};
	$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);");

	# Inner product gets a lower recall threshold than the other operators
	my $min_recall = $operator eq "<#>" ? 0.85 : 0.99;
	test_recall($min_recall, $operator);
}

done_testing();

View File

@@ -0,0 +1,103 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
# Check that index scans with the given operator find at least $min of the
# exact results stored in @expected, across all of @queries.
#
# $min      - minimum acceptable fraction of expected ids found
# $operator - distance operator to order by; also used as the test name
sub test_recall
{
	my ($min, $operator) = @_;
	my $correct = 0;
	my $total = 0;

	# Make sure the planner actually uses the index
	my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
	like($explain, qr/Index Scan/, "$operator: uses index scan");

	for my $i (0 .. $#queries) {
		my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
		my @actual_ids = split("\n", $actual);
		my %actual_set = map { $_ => 1 } @actual_ids;

		my @expected_ids = split("\n", $expected[$i]);

		for my $id (@expected_ids) {
			$correct++ if exists($actual_set{$id});
			$total++;
		}
	}

	# Report a failed check instead of dying on division by zero when there
	# were no expected rows to compare against
	my $recall = $total > 0 ? $correct / $total : 0;

	return cmp_ok($recall, ">=", $min, $operator);
}
# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;

# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");

# Generate 20 random three-dimensional query vectors
for (1 .. 20) {
	push(@queries, "[" . join(",", map { rand() } 1 .. 3) . "]");
}

# Operator class to index with for each distance operator
my %opclass_for = (
	"<->" => "vector_l2_ops",
	"<#>" => "vector_ip_ops",
	"<=>" => "vector_cosine_ops",
);

# Check each index type
my @operators = ("<->", "<#>", "<=>");
for my $operator (@operators) {
	# Add index before the data so that rows go through the insert path
	my $opclass = $opclass_for{$operator};
	$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);");

	$node->safe_psql("postgres",
		"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
	);

	# Get exact results with index scans disabled
	@expected = ();
	for my $query (@queries) {
		my $exact = $node->safe_psql("postgres", qq(
SET enable_indexscan = off;
SELECT i FROM tst ORDER BY v $operator '$query' LIMIT $limit;
));
		push(@expected, $exact);
	}

	# Inner product gets a lower recall threshold than the other operators
	my $min_recall = $operator eq "<#>" ? 0.85 : 0.99;
	test_recall($min_recall, $operator);

	# Start the next operator from an empty, unindexed table
	$node->safe_psql("postgres", "DROP INDEX idx;");
	$node->safe_psql("postgres", "TRUNCATE tst;");
}

done_testing();
done_testing();

View File

@@ -0,0 +1,54 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;
# Number of dimensions of the test vectors
my $dim = 768;
# Comma-separated list of $dim random() calls, used inside ARRAY[...]
my $array_sql = join(",", ('random()') x $dim);
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table and index
# The index is created while the table is empty, so all rows below are
# added to it by INSERT
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));");
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
# Run 5 clients with 100 transactions each; every transaction inserts 10 rows
$node->pgbench(
"--no-vacuum --client=5 --transactions=100",
0,
[qr{actually processed}],
[qr{^$}],
"concurrent INSERTs",
{
"007_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;"
}
);
# Return the idx_scan counter that pg_stat_user_indexes reports for tst_v_idx.
sub idx_scan
{
	my $stat_sql = "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;";

	# Stats do not update instantaneously
	# https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS
	sleep(1);

	return $node->safe_psql("postgres", $stat_sql);
}
# 5 clients x 100 transactions x 10 rows per transaction
my $expected = 5 * 100 * 10;

my $count = $node->safe_psql("postgres", "SELECT COUNT(*) FROM tst;");
is($count, $expected, "all concurrently inserted rows are present");
is(idx_scan(), 0, "index has not been scanned yet");

# Count the rows returned by a single ordered scan with ef_search set to 400
$count = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = 400;
SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t;
));
is($count, 400, "ordered scan returns ef_search rows");
is(idx_scan(), 1, "ordered scan used the index once");

done_testing();