Added support for inline filtering with HNSW

This commit is contained in:
Andrew Kane
2024-10-09 19:02:40 -07:00
parent 3126fbdb6f
commit 3ccfab8f92
12 changed files with 567 additions and 128 deletions

View File

@@ -1,5 +1,6 @@
## 0.8.0 (unreleased)
- Added support for inline filtering with HNSW
- Added casts for arrays to `sparsevec`
- Improved cost estimation
- Improved performance of HNSW inserts and on-disk index builds

View File

@@ -439,6 +439,12 @@ Create an index on one [or more](https://www.postgresql.org/docs/current/indexes
CREATE INDEX ON items (category_id);
```
Or a composite HNSW index for approximate search (added in 0.8.0)
```sql
CREATE INDEX ON items USING hnsw (embedding vector_l2_ops, category_id);
```
Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search
```sql
@@ -1189,6 +1195,7 @@ Thanks to:
- [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf)
- [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf)
- [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf)
- [HQANN: Efficient and Robust Similarity Search for Hybrid Queries with Structured and Unstructured Constraints](https://arxiv.org/pdf/2207.07940.pdf)
## History

View File

@@ -24,3 +24,11 @@ CREATE CAST (double precision[] AS sparsevec)
CREATE CAST (numeric[] AS sparsevec)
WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT;
CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8
AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE OPERATOR CLASS vector_integer_ops
DEFAULT FOR TYPE integer USING hnsw AS
OPERATOR 2 = (integer, integer),
FUNCTION 4 hnsw_attribute_distance(integer, integer);

View File

@@ -916,3 +916,13 @@ CREATE OPERATOR CLASS sparsevec_l1_ops
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 l1_distance(sparsevec, sparsevec),
FUNCTION 3 hnsw_sparsevec_support(internal);
-- hnsw attributes
CREATE FUNCTION hnsw_attribute_distance(integer, integer) RETURNS float8
AS 'MODULE_PATHNAME', 'hnsw_int4_attribute_distance' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE OPERATOR CLASS vector_integer_ops
DEFAULT FOR TYPE integer USING hnsw AS
OPERATOR 2 = (integer, integer),
FUNCTION 4 hnsw_attribute_distance(integer, integer);

View File

@@ -227,13 +227,13 @@ hnswhandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 3;
amroutine->amsupport = 4;
amroutine->amoptsprocnum = 0;
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false; /* can change direction mid-scan */
amroutine->amcanunique = false;
amroutine->amcanmulticol = false;
amroutine->amcanmulticol = true;
amroutine->amoptionalkey = true;
amroutine->amsearcharray = false;
amroutine->amsearchnulls = false;
@@ -285,3 +285,17 @@ hnswhandler(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(amroutine);
}
/*
* Get the distance between two int4 attributes
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_int4_attribute_distance);
Datum
hnsw_int4_attribute_distance(PG_FUNCTION_ARGS)
{
int32 a = PG_GETARG_INT32(0);
int32 b = PG_GETARG_INT32(1);
double distance = ((double) a) - ((double) b);
PG_RETURN_FLOAT8(distance);
}

View File

@@ -19,6 +19,7 @@
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
#define HNSW_TYPE_INFO_PROC 3
#define HNSW_ATTRIBUTE_DISTANCE_PROC 4
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -104,6 +105,8 @@
#define HnswPtrPointer(hp) (hp).ptr
#define HnswPtrOffset(hp) relptr_offset((hp).relptr)
#define HnswUseIndexTuple(index) (IndexRelationGetNumberOfAttributes(index) > 1)
/* Variables */
extern int hnsw_ef_search;
extern int hnsw_lock_tranche_id;
@@ -121,6 +124,7 @@ HnswPtrDeclare(HnswElementData, HnswElementRelptr, HnswElementPtr);
HnswPtrDeclare(HnswNeighborArray, HnswNeighborArrayRelptr, HnswNeighborArrayPtr);
HnswPtrDeclare(HnswNeighborArrayPtr, HnswNeighborsRelptr, HnswNeighborsPtr);
HnswPtrDeclare(char, DatumRelptr, DatumPtr);
HnswPtrDeclare(IndexTupleData, IndexTupleRelptr, IndexTuplePtr);
struct HnswElementData
{
@@ -136,6 +140,7 @@ struct HnswElementData
OffsetNumber neighborOffno;
BlockNumber neighborPage;
DatumPtr value;
IndexTuplePtr itup;
LWLock lock;
};
@@ -161,6 +166,7 @@ typedef struct HnswSearchCandidate
pairingheap_node w_node;
HnswElementPtr element;
double distance;
bool matches;
} HnswSearchCandidate;
/* HNSW index options */
@@ -256,15 +262,17 @@ typedef struct HnswBuildState
double reltuples;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *procinfo[2];
FmgrInfo *normprocinfo;
Oid collation;
Oid *collation;
/* Variables */
HnswGraph graphData;
HnswGraph *graph;
double ml;
int maxLevel;
bool useIndexTuple;
TupleDesc tupdesc;
/* Memory */
MemoryContext graphCtx;
@@ -333,9 +341,9 @@ typedef struct HnswScanOpaqueData
MemoryContext tmpCtx;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *procinfo[2];
FmgrInfo *normprocinfo;
Oid collation;
Oid *collation;
} HnswScanOpaqueData;
typedef HnswScanOpaqueData * HnswScanOpaque;
@@ -353,8 +361,8 @@ typedef struct HnswVacuumState
int efConstruction;
/* Support functions */
FmgrInfo *procinfo;
Oid collation;
FmgrInfo *procinfo[2];
Oid *collation;
/* Variables */
struct tidhash_hash *deleted;
@@ -375,29 +383,31 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInit(void);
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement);
List *HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, int m, bool inserting, HnswElement skipElement, bool inMemory);
HnswElement HnswGetEntryPoint(Relation index);
void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
void *HnswAlloc(HnswAllocator * allocator, Size size);
HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc);
HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno);
void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing);
HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec);
void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfo, Oid *collation, int m, int efConstruction, bool existing, bool inMemory);
HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation rel, FmgrInfo **procinfo, Oid *collation, bool loadVec, bool inMemory);
void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building);
void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m);
void HnswAddHeapTid(HnswElement element, ItemPointer heaptid);
HnswNeighborArray *HnswInitNeighborArray(int lm, HnswAllocator * allocator);
void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc);
bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building);
void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building);
void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec);
void HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance);
void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element);
void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation);
void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m, bool checkExisting, bool building);
void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index);
void HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance);
void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple);
void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo **procinfo, Oid *collation);
bool HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc);
void HnswInitLockTranche(void);
const HnswTypeInfo *HnswGetTypeInfo(Relation index);
PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc);
void HnswInitProcinfo(FmgrInfo **procinfo, Relation index);
bool HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc);
/* Index access methods */
IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo);

View File

@@ -148,6 +148,7 @@ CreateGraphPages(HnswBuildState * buildstate)
Page page;
HnswElementPtr iter = buildstate->graph->head;
char *base = buildstate->hnswarea;
bool useIndexTuple = buildstate->useIndexTuple;
/* Calculate sizes */
maxSize = HNSW_MAX_SIZE;
@@ -167,7 +168,6 @@ CreateGraphPages(HnswBuildState * buildstate)
Size etupSize;
Size ntupSize;
Size combinedSize;
Pointer valuePtr = HnswPtrAccess(base, element->value);
/* Update iterator */
iter = element->next;
@@ -176,7 +176,7 @@ CreateGraphPages(HnswBuildState * buildstate)
MemSet(etup, 0, HNSW_TUPLE_ALLOC_SIZE);
/* Calculate sizes */
etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(valuePtr));
etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(HnswPtrAccess(base, element->itup)) : VARSIZE_ANY(HnswPtrAccess(base, element->value)));
ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m);
combinedSize = etupSize + ntupSize + sizeof(ItemIdData);
@@ -186,7 +186,7 @@ CreateGraphPages(HnswBuildState * buildstate)
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("index tuple too large")));
HnswSetElementTuple(base, etup, element);
HnswSetElementTuple(base, etup, element, useIndexTuple);
/* Keep element and neighbors on the same page if possible */
if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize))
@@ -327,20 +327,29 @@ AddDuplicateInMemory(HnswElement element, HnswElement dup)
* Find duplicate element
*/
static bool
FindDuplicateInMemory(char *base, HnswElement element)
FindDuplicateInMemory(char *base, HnswElement element, bool useIndexTuple, TupleDesc tupdesc)
{
HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0);
Datum value = HnswGetValue(base, element);
IndexTuple itup = HnswPtrAccess(base, element->itup);
for (int i = 0; i < neighbors->length; i++)
{
HnswCandidate *neighbor = &neighbors->items[i];
HnswElement neighborElement = HnswPtrAccess(base, neighbor->element);
Datum neighborValue = HnswGetValue(base, neighborElement);
/* Exit early since ordered by distance */
if (!datumIsEqual(value, neighborValue, false, -1))
return false;
if (useIndexTuple)
{
/* Exit early since ordered by distance */
if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc))
return false;
}
else
{
/* Exit early since ordered by distance */
if (!datumIsEqual(value, HnswGetValue(base, neighborElement), false, -1))
return false;
}
/* Check for space */
if (AddDuplicateInMemory(element, neighborElement))
@@ -366,7 +375,7 @@ AddElementInMemory(char *base, HnswGraph * graph, HnswElement element)
* Update neighbors
*/
static void
UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswElement e, int m)
UpdateNeighborsInMemory(char *base, Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m)
{
for (int lc = e->level; lc >= 0; lc--)
{
@@ -388,7 +397,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme
Assert(neighborElement);
LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE);
HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, NULL, procinfo, collation);
HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, index, procinfo, collation);
LWLockRelease(&neighborElement->lock);
}
}
@@ -398,20 +407,20 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme
* Update graph in memory
*/
static void
UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate)
UpdateGraphInMemory(FmgrInfo **procinfo, Oid *collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate)
{
HnswGraph *graph = buildstate->graph;
char *base = buildstate->hnswarea;
/* Look for duplicate */
if (FindDuplicateInMemory(base, element))
if (FindDuplicateInMemory(base, element, buildstate->useIndexTuple, buildstate->tupdesc))
return;
/* Add element */
AddElementInMemory(base, graph, element);
/* Update neighbors */
UpdateNeighborsInMemory(base, procinfo, collation, element, m);
UpdateNeighborsInMemory(base, buildstate->index, procinfo, collation, element, m);
/* Update entry point if needed (already have lock) */
if (entryPoint == NULL || element->level > entryPoint->level)
@@ -424,8 +433,9 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int
static void
InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element)
{
FmgrInfo *procinfo = buildstate->procinfo;
Oid collation = buildstate->collation;
Relation index = buildstate->index;
FmgrInfo **procinfo = buildstate->procinfo;
Oid *collation = buildstate->collation;
HnswGraph *graph = buildstate->graph;
HnswElement entryPoint;
LWLock *entryLock = &graph->entryLock;
@@ -458,7 +468,7 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element)
}
/* Find neighbors for element */
HnswFindElementNeighbors(base, element, entryPoint, NULL, procinfo, collation, m, efConstruction, false);
HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false, true);
/* Update graph in memory */
UpdateGraphInMemory(procinfo, collation, element, m, efConstruction, entryPoint, buildstate);
@@ -481,6 +491,11 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
Pointer valuePtr;
LWLock *flushLock = &graph->flushLock;
char *base = buildstate->hnswarea;
bool useIndexTuple = buildstate->useIndexTuple;
TupleDesc tupdesc = buildstate->tupdesc;
IndexTuple itup;
Size itupSize;
IndexTuple itupPtr;
/* Detoast once for all calls */
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
@@ -492,10 +507,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation[0], value))
return false;
value = HnswNormValue(typeInfo, buildstate->collation, value);
value = HnswNormValue(typeInfo, buildstate->collation[0], value);
}
/* Get datum size */
@@ -546,7 +561,17 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
/* Ok, we can proceed to allocate the element */
element = HnswInitElement(base, heaptid, buildstate->m, buildstate->ml, buildstate->maxLevel, allocator);
valuePtr = HnswAlloc(allocator, valueSize);
if (useIndexTuple)
{
/* TODO fix */
values[0] = value;
itup = index_form_tuple(tupdesc, values, isnull);
itupSize = IndexTupleSize(itup);
itupPtr = HnswAlloc(allocator, itupSize);
}
else
valuePtr = HnswAlloc(allocator, valueSize);
/*
* We have now allocated the space needed for the element, so we don't
@@ -556,8 +581,19 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
LWLockRelease(&graph->allocatorLock);
/* Copy the datum */
memcpy(valuePtr, DatumGetPointer(value), valueSize);
HnswPtrStore(base, element->value, valuePtr);
if (useIndexTuple)
{
bool unused;
memcpy(itupPtr, itup, itupSize);
HnswPtrStore(base, element->itup, itupPtr);
HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itupPtr, 1, tupdesc, &unused)));
}
else
{
memcpy(valuePtr, DatumGetPointer(value), valueSize);
HnswPtrStore(base, element->value, valuePtr);
}
/* Create a lock for the element */
LWLockInitialize(&element->lock, hnsw_lock_tranche_id);
@@ -684,6 +720,19 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("type not supported for hnsw index")));
/* TODO See if needed */
if (IndexRelationGetNumberOfKeyAttributes(index) > 2)
elog(ERROR, "index cannot have more than two columns");
if (!OidIsValid(index_getprocid(index, 1, HNSW_DISTANCE_PROC)))
elog(ERROR, "first column must be a vector");
for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++)
{
if (!OidIsValid(index_getprocid(index, i + 1, HNSW_ATTRIBUTE_DISTANCE_PROC)))
elog(ERROR, "column %d cannot be a vector", i + 1);
}
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
ereport(ERROR,
@@ -704,14 +753,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->indtuples = 0;
/* Get support functions */
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
HnswInitProcinfo(buildstate->procinfo, index);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
buildstate->collation = index->rd_indcollation[0];
buildstate->collation = index->rd_indcollation;
InitGraph(&buildstate->graphData, NULL, (Size) maintenance_work_mem * 1024L);
buildstate->graph = &buildstate->graphData;
buildstate->ml = HnswGetMl(buildstate->m);
buildstate->maxLevel = HnswGetMaxLevel(buildstate->m);
buildstate->useIndexTuple = HnswUseIndexTuple(index);
buildstate->tupdesc = RelationGetDescr(index);
buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext,
"Hnsw build graph context",

View File

@@ -154,9 +154,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
OffsetNumber freeNeighborOffno = InvalidOffsetNumber;
BlockNumber newInsertPage = InvalidBlockNumber;
char *base = NULL;
bool useIndexTuple = HnswUseIndexTuple(index);
/* Calculate sizes */
etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(HnswPtrAccess(base, e->value)));
etupSize = HNSW_ELEMENT_TUPLE_SIZE(useIndexTuple ? IndexTupleSize(HnswPtrAccess(base, e->itup)) : VARSIZE_ANY(HnswPtrAccess(base, e->value)));
ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m);
combinedSize = etupSize + ntupSize + sizeof(ItemIdData);
maxSize = HNSW_MAX_SIZE;
@@ -164,7 +165,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
/* Prepare element tuple */
etup = palloc0(etupSize);
HnswSetElementTuple(base, etup, e);
HnswSetElementTuple(base, etup, e, useIndexTuple);
/* Prepare neighbor tuple */
ntup = palloc0(ntupSize);
@@ -368,7 +369,7 @@ HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc)
* Load elements for insert
*/
static void
LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation index, FmgrInfo *procinfo, Oid collation)
LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, IndexTuple qtup, int *idx, Relation index, FmgrInfo **procinfo, Oid *collation)
{
char *base = NULL;
@@ -377,8 +378,9 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation
HnswCandidate *hc = &neighbors->items[i];
HnswElement element = HnswPtrAccess(base, hc->element);
double distance;
bool matches;
HnswLoadElement(element, &distance, &q, index, procinfo, collation, true, NULL);
HnswLoadElement(element, &distance, &matches, &q, qtup, NULL, index, procinfo, collation, true, NULL);
hc->distance = distance;
/* Prune element if being deleted */
@@ -394,7 +396,7 @@ LoadElementsForInsert(HnswNeighborArray * neighbors, Datum q, int *idx, Relation
* Get update index
*/
static int
GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, FmgrInfo *procinfo, Oid collation, MemoryContext updateCtx)
GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, MemoryContext updateCtx)
{
char *base = NULL;
int idx = -1;
@@ -420,8 +422,9 @@ GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int
else
{
Datum q = HnswGetValue(base, element);
IndexTuple qtup = HnswPtrAccess(base, element->itup);;
LoadElementsForInsert(neighbors, q, &idx, index, procinfo, collation);
LoadElementsForInsert(neighbors, q, qtup, &idx, index, procinfo, collation);
if (idx == -1)
HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, procinfo, collation);
@@ -529,7 +532,7 @@ UpdateNeighborOnDisk(HnswElement element, HnswElement newElement, int idx, int m
* Update neighbors
*/
void
HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building)
HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement e, int m, bool checkExisting, bool building)
{
char *base = NULL;
@@ -630,16 +633,26 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building)
char *base = NULL;
HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0);
Datum value = HnswGetValue(base, element);
IndexTuple itup = HnswPtrAccess(base, element->itup);
TupleDesc tupdesc = RelationGetDescr(index);
for (int i = 0; i < neighbors->length; i++)
{
HnswCandidate *neighbor = &neighbors->items[i];
HnswElement neighborElement = HnswPtrAccess(base, neighbor->element);
Datum neighborValue = HnswGetValue(base, neighborElement);
/* Exit early since ordered by distance */
if (!datumIsEqual(value, neighborValue, false, -1))
return false;
if (HnswUseIndexTuple(index))
{
/* Exit early since ordered by distance */
if (!HnswIndexTupleIsEqual(itup, HnswPtrAccess(base, neighborElement->itup), tupdesc))
return false;
}
else
{
/* Exit early since ordered by distance */
if (!datumIsEqual(value, HnswGetValue(base, neighborElement), false, -1))
return false;
}
if (AddDuplicateOnDisk(index, element, neighborElement, building))
return true;
@@ -652,7 +665,7 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building)
* Update graph on disk
*/
static void
UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building)
UpdateGraphOnDisk(Relation index, FmgrInfo **procinfo, Oid *collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building)
{
BlockNumber newInsertPage = InvalidBlockNumber;
@@ -685,11 +698,13 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull,
HnswElement element;
int m;
int efConstruction = HnswGetEfConstruction(index);
FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
Oid collation = index->rd_indcollation[0];
FmgrInfo *procinfo[2];
Oid *collation = index->rd_indcollation;
LOCKMODE lockmode = ShareLock;
char *base = NULL;
HnswInitProcinfo(procinfo, index);
/*
* Get a shared lock. This allows vacuum to ensure no in-flight inserts
* before repairing graph. Use a page lock so it does not interfere with
@@ -702,7 +717,23 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull,
/* Create an element */
element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL);
HnswPtrStore(base, element->value, DatumGetPointer(value));
if (HnswUseIndexTuple(index))
{
/* TODO no toast */
TupleDesc tupdesc = RelationGetDescr(index);
IndexTuple itup;
bool unused;
/* TODO fix */
values[0] = value;
itup = index_form_tuple(tupdesc, values, isnull);
HnswPtrStore(base, element->itup, itup);
HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused)));
}
else
HnswPtrStore(base, element->value, DatumGetPointer(value));
/* Prevent concurrent inserts when likely updating entry point */
if (entryPoint == NULL || element->level > entryPoint->level)
@@ -719,7 +750,7 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull,
}
/* Find neighbors for element */
HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false);
HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false, false);
/* Update graph on disk */
UpdateGraphOnDisk(index, procinfo, collation, element, m, efConstruction, entryPoint, building);
@@ -739,7 +770,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
Datum value;
const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index);
FmgrInfo *normprocinfo;
Oid collation = index->rd_indcollation[0];
Oid *collation = index->rd_indcollation;
/* Detoast once for all calls */
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
@@ -752,10 +783,10 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
if (normprocinfo != NULL)
{
if (!HnswCheckNorm(normprocinfo, collation, value))
if (!HnswCheckNorm(normprocinfo, collation[0], value))
return;
value = HnswNormValue(typeInfo, collation, value);
value = HnswNormValue(typeInfo, collation[0], value);
}
HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false);

View File

@@ -15,13 +15,15 @@ GetScanItems(IndexScanDesc scan, Datum q)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
Relation index = scan->indexRelation;
FmgrInfo *procinfo = so->procinfo;
Oid collation = so->collation;
FmgrInfo **procinfo = so->procinfo;
Oid *collation = so->collation;
List *ep;
List *w;
int m;
HnswElement entryPoint;
char *base = NULL;
bool inMemory = false;
ScanKeyData *keyData = scan->keyData;
/* Get m and entry point */
HnswGetMetaPageInfo(index, &m, &entryPoint);
@@ -29,15 +31,15 @@ GetScanItems(IndexScanDesc scan, Datum q)
if (entryPoint == NULL)
return NIL;
ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false));
ep = list_make1(HnswEntryCandidate(base, entryPoint, q, NULL, keyData, index, procinfo, collation, false, inMemory));
for (int lc = entryPoint->level; lc >= 1; lc--)
{
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL);
w = HnswSearchLayer(base, q, NULL, keyData, ep, 1, lc, index, procinfo, collation, m, false, NULL, inMemory);
ep = w;
}
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL);
return HnswSearchLayer(base, q, NULL, keyData, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, inMemory);
}
/*
@@ -61,7 +63,7 @@ GetScanValue(IndexScanDesc scan)
/* Normalize if needed */
if (so->normprocinfo != NULL)
value = HnswNormValue(so->typeInfo, so->collation, value);
value = HnswNormValue(so->typeInfo, so->collation[0], value);
}
return value;
@@ -86,9 +88,9 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
ALLOCSET_DEFAULT_SIZES);
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
HnswInitProcinfo(so->procinfo, index);
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
so->collation = index->rd_indcollation[0];
so->collation = index->rd_indcollation;
scan->opaque = so;
@@ -173,7 +175,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
ItemPointer heaptid;
/* Move to next element if no valid heap TIDs */
if (element->heaptidsLength == 0)
if (!hc->matches || element->heaptidsLength == 0)
{
so->w = list_delete_last(so->w);
continue;

View File

@@ -153,6 +153,18 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
return index_getprocinfo(index, 1, procnum);
}
/*
* Init procinfo
*/
void
HnswInitProcinfo(FmgrInfo **procinfo, Relation index)
{
procinfo[0] = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
if (IndexRelationGetNumberOfKeyAttributes(index) > 1)
procinfo[1] = index_getprocinfo(index, 2, HNSW_ATTRIBUTE_DISTANCE_PROC);
}
/*
* Normalize value
*/
@@ -171,6 +183,37 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0;
}
/*
* Check if index tuples are equal
*/
bool
HnswIndexTupleIsEqual(IndexTuple a, IndexTuple b, TupleDesc tupdesc)
{
for (int i = 0; i < tupdesc->natts; i++)
{
bool nullA;
bool nullB;
Datum datumA = index_getattr(a, i + 1, tupdesc, &nullA);
Datum datumB = index_getattr(b, i + 1, tupdesc, &nullB);
if (nullA || nullB)
{
if (nullA != nullB)
return false;
}
else
{
Form_pg_attribute att = TupleDescAttr(tupdesc, i);
if (!datumIsEqual(datumA, datumB, att->attbyval, att->attlen))
return false;
}
}
return true;
}
/*
* New buffer
*/
@@ -257,6 +300,7 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel,
HnswInitNeighbors(base, element, m, allocator);
HnswPtrStore(base, element->value, (Pointer) NULL);
HnswPtrStore(base, element->itup, (IndexTuple) NULL);
return element;
}
@@ -283,6 +327,7 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno)
element->offno = offno;
HnswPtrStore(base, element->neighbors, (HnswNeighborArrayPtr *) NULL);
HnswPtrStore(base, element->value, (Pointer) NULL);
HnswPtrStore(base, element->itup, (IndexTuple) NULL);
return element;
}
@@ -398,10 +443,8 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc
* Set element tuple, except for neighbor info
*/
void
HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element)
HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element, bool useIndexTuple)
{
Pointer valuePtr = HnswPtrAccess(base, element->value);
etup->type = HNSW_ELEMENT_TUPLE_TYPE;
etup->level = element->level;
etup->deleted = 0;
@@ -412,7 +455,19 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element)
else
ItemPointerSetInvalid(&etup->heaptids[i]);
}
memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr));
if (useIndexTuple)
{
IndexTuple itup = HnswPtrAccess(base, element->itup);
memcpy(&etup->data, itup, IndexTupleSize(itup));
}
else
{
Pointer valuePtr = HnswPtrAccess(base, element->value);
memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr));
}
}
/*
@@ -453,7 +508,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m)
* Load an element from a tuple
*/
void
HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec)
HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index)
{
element->level = etup->level;
element->deleted = etup->deleted;
@@ -476,26 +531,128 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
if (loadVec)
{
char *base = NULL;
Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1);
HnswPtrStore(base, element->value, DatumGetPointer(value));
if (HnswUseIndexTuple(index))
{
IndexTuple itup = CopyIndexTuple((IndexTuple) &etup->data);
TupleDesc tupdesc = RelationGetDescr(index);
bool unused;
HnswPtrStore(base, element->itup, itup);
HnswPtrStore(base, element->value, DatumGetPointer(index_getattr(itup, 1, tupdesc, &unused)));
}
else
{
Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1);
HnswPtrStore(base, element->value, DatumGetPointer(value));
}
}
}
/*
* Get the attribute distance
*/
static inline double
AttributeDistance(double e)
{
/* TODO Better bias */
/* must be >> max(w * g) + 1 / log10(2) */
double bias = 4.32;
return e > 0 ? bias - 1.0 / log10(e + 1) : 0;
}
/*
* Calculate the distance between values
*/
static inline double
HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation)
static double
HnswGetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool *matches)
{
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b));
double g;
if (DatumGetPointer(q) == NULL)
g = 0;
else
g = DatumGetFloat8(FunctionCall2Coll(procinfo[0], collation[0], q, vec));
Assert(PointerIsValid(matches));
*matches = true;
if (IndexRelationGetNumberOfKeyAttributes(index) > 1)
{
double w = 0.25;
double e = 0.0;
TupleDesc tupdesc = RelationGetDescr(index);
if (keyData)
{
/* TODO need to pass length of key data */
int keyCount = 1;
for (int i = 0; i < keyCount; i++)
{
ScanKey key = &keyData[i];
bool isnull;
Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull);
bool attnull = key->sk_flags & SK_ISNULL;
if (isnull || attnull)
{
if (isnull != attnull)
{
e += 1000;
*matches = false;
}
}
else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument)))
{
double ei = fabs(DatumGetFloat8(FunctionCall2Coll(procinfo[key->sk_attno - 1], collation[key->sk_attno - 1], value, key->sk_argument)));
if (ei > 0)
e += ei;
else
/* Distance is zero for inequality */
e += 1000;
*matches = false;
}
}
return w * g + AttributeDistance(e);
}
else if (qtup)
{
int keyCount = IndexRelationGetNumberOfKeyAttributes(index) - 1;
for (int i = 0; i < keyCount; i++)
{
bool isnull;
bool attnull;
Datum value = index_getattr(itup, i + 2, tupdesc, &isnull);
Datum value2 = index_getattr(qtup, i + 2, tupdesc, &attnull);
if (isnull || attnull)
{
if (isnull != attnull)
e += 1000;
}
else
e += fabs(DatumGetFloat8(FunctionCall2Coll(procinfo[i + 1], collation[i + 1], value, value2)));
}
return w * g + AttributeDistance(e);
}
}
return g;
}
/*
* Load an element and optionally get its distance from q
*/
static void
HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance, HnswElement * element)
HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance, HnswElement * element)
{
Buffer buf;
Page page;
@@ -513,10 +670,23 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
/* Calculate distance */
if (distance != NULL)
{
if (DatumGetPointer(*q) == NULL)
*distance = 0;
IndexTuple itup = NULL;
Datum value;
if (HnswUseIndexTuple(index))
{
TupleDesc tupdesc = RelationGetDescr(index);
bool unused;
itup = (IndexTuple) &etup->data;
value = index_getattr(itup, 1, tupdesc, &unused);
}
else
*distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation);
{
value = PointerGetDatum(&etup->data);
}
*distance = HnswGetDistance(itup, value, *q, qtup, keyData, index, procinfo, collation, matches);
}
/* Load element */
@@ -525,7 +695,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
if (*element == NULL)
*element = HnswInitElementFromBlock(blkno, offno);
HnswLoadElementFromTuple(*element, etup, true, loadVec);
HnswLoadElementFromTuple(*element, etup, true, loadVec, index);
}
UnlockReleaseBuffer(buf);
@@ -535,35 +705,36 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
* Load an element and optionally get its distance from q
*/
void
HnswLoadElement(HnswElement element, double *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, double *maxDistance)
HnswLoadElement(HnswElement element, double *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, double *maxDistance)
{
HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element);
HnswLoadElementImpl(element->blkno, element->offno, distance, matches, q, qtup, keyData, index, procinfo, collation, loadVec, maxDistance, &element);
}
/*
* Get the distance for an element
*/
static double
GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation)
GetElementDistance(char *base, HnswElement element, bool *matches, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation)
{
Datum value = HnswGetValue(base, element);
IndexTuple itup = HnswPtrAccess(base, element->itup);
return HnswGetDistance(q, value, procinfo, collation);
return HnswGetDistance(itup, value, q, qtup, keyData, index, procinfo, collation, matches);
}
/*
* Create a candidate for the entry point
*/
HnswSearchCandidate *
HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec)
HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfo, Oid *collation, bool loadVec, bool inMemory)
{
HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate));
HnswPtrStore(base, sc->element, entryPoint);
if (index == NULL)
sc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation);
if (inMemory)
sc->distance = GetElementDistance(base, entryPoint, &sc->matches, q, qtup, keyData, index, procinfo, collation);
else
HnswLoadElement(entryPoint, &sc->distance, &q, index, procinfo, collation, loadVec, NULL);
HnswLoadElement(entryPoint, &sc->distance, &sc->matches, &q, qtup, keyData, index, procinfo, collation, loadVec, NULL);
return sc;
}
@@ -604,9 +775,9 @@ CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b,
* Init visited
*/
static inline void
InitVisited(char *base, visited_hash * v, Relation index, int ef, int m)
InitVisited(char *base, visited_hash * v, bool inMemory, int ef, int m)
{
if (index != NULL)
if (!inMemory)
v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL);
else if (base != NULL)
v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL);
@@ -618,9 +789,9 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m)
* Add to visited
*/
static inline void
AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found)
AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, bool inMemory, bool *found)
{
if (index != NULL)
if (!inMemory)
{
HnswElement element = HnswPtrAccess(base, elementPtr);
ItemPointerData indextid;
@@ -681,7 +852,7 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv
HnswCandidate *hc = &localNeighborhood->items[i];
bool found;
AddToVisited(base, v, hc->element, NULL, &found);
AddToVisited(base, v, hc->element, true, &found);
if (!found)
unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element);
@@ -752,7 +923,7 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u
* Algorithm 2 from paper
*/
List *
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement)
HnswSearchLayer(char *base, Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef, int lc, Relation index, FmgrInfo **procinfo, Oid *collation, int m, bool inserting, HnswElement skipElement, bool inMemory)
{
List *w = NIL;
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
@@ -765,11 +936,13 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
int lm = HnswGetLayerM(m, lc);
HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited));
int unvisitedLength;
uint64 additional = 0;
uint64 maxAdditional = keyData && lc == 0 ? 10000 : 0;
InitVisited(base, &v, index, ef, m);
InitVisited(base, &v, inMemory, ef, m);
/* Create local memory for neighborhood if needed */
if (index == NULL)
if (inMemory)
{
neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm);
localNeighborhood = palloc(neighborhoodSize);
@@ -781,11 +954,15 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2);
bool found;
AddToVisited(base, &v, sc->element, index, &found);
AddToVisited(base, &v, sc->element, inMemory, &found);
pairingheap_add(C, &sc->c_node);
pairingheap_add(W, &sc->w_node);
/* Do not count elements that do not match filter towards ef */
if (!sc->matches && ++additional <= maxAdditional)
continue;
/*
* Do not count elements being deleted towards ef when vacuuming. It
* would be ideal to do this for inserts as well, but this could
@@ -806,7 +983,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
cElement = HnswPtrAccess(base, c->element);
if (index == NULL)
if (inMemory)
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize);
else
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc);
@@ -816,14 +993,15 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
HnswElement eElement;
HnswSearchCandidate *e;
double eDistance;
bool eMatches;
bool alwaysAdd = wlen < ef;
f = HnswGetSearchCandidate(w_node, pairingheap_first(W));
if (index == NULL)
if (inMemory)
{
eElement = unvisited[i].element;
eDistance = GetElementDistance(base, eElement, q, procinfo, collation);
eDistance = GetElementDistance(base, eElement, &eMatches, q, qtup, keyData, index, procinfo, collation);
}
else
{
@@ -833,7 +1011,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
/* Avoid any allocations if not adding */
eElement = NULL;
HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement);
HnswLoadElementImpl(blkno, offno, &eDistance, &eMatches, &q, qtup, keyData, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement);
if (eElement == NULL)
continue;
@@ -852,6 +1030,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
e = palloc(sizeof(HnswSearchCandidate));
HnswPtrStore(base, e->element, eElement);
e->distance = eDistance;
e->matches = eMatches;
pairingheap_add(C, &e->c_node);
pairingheap_add(W, &e->w_node);
@@ -862,6 +1041,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
*/
if (CountElement(skipElement, eElement))
{
/* Do not count elements that do not match filter towards ef */
if (!e->matches && ++additional <= maxAdditional)
continue;
wlen++;
/* No need to decrement wlen */
@@ -934,10 +1117,11 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b)
* Check if an element is closer to q than any element from R
*/
static bool
CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation)
CheckElementCloser(char *base, HnswCandidate * e, List *r, Relation index, FmgrInfo **procinfo, Oid *collation)
{
HnswElement eElement = HnswPtrAccess(base, e->element);
Datum eValue = HnswGetValue(base, eElement);
IndexTuple etup = HnswPtrAccess(base, eElement->itup);
ListCell *lc2;
foreach(lc2, r)
@@ -945,7 +1129,9 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O
HnswCandidate *ri = lfirst(lc2);
HnswElement riElement = HnswPtrAccess(base, ri->element);
Datum riValue = HnswGetValue(base, riElement);
float distance = HnswGetDistance(eValue, riValue, procinfo, collation);
IndexTuple ritup = HnswPtrAccess(base, riElement->itup);
bool matches;
float distance = HnswGetDistance(etup, eValue, riValue, ritup, NULL, index, procinfo, collation, &matches);
if (distance <= e->distance)
return false;
@@ -958,7 +1144,7 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O
* Algorithm 4 from paper
*/
static List *
SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates)
SelectNeighbors(char *base, List *c, int lm, Relation index, FmgrInfo **procinfo, Oid *collation, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates)
{
List *r = NIL;
List *w = list_copy(c);
@@ -992,7 +1178,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation,
/* Use previous state of r and wd to skip work when possible */
if (mustCalculate)
e->closer = CheckElementCloser(base, e, r, procinfo, collation);
e->closer = CheckElementCloser(base, e, r, index, procinfo, collation);
else if (list_length(added) > 0)
{
/* Keep Valgrind happy for in-memory, parallel builds */
@@ -1005,7 +1191,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation,
*/
if (e->closer)
{
e->closer = CheckElementCloser(base, e, added, procinfo, collation);
e->closer = CheckElementCloser(base, e, added, index, procinfo, collation);
if (!e->closer)
removedAny = true;
@@ -1018,7 +1204,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation,
*/
if (removedAny)
{
e->closer = CheckElementCloser(base, e, r, procinfo, collation);
e->closer = CheckElementCloser(base, e, r, index, procinfo, collation);
if (e->closer)
added = lappend(added, e);
}
@@ -1026,7 +1212,7 @@ SelectNeighbors(char *base, List *c, int lm, FmgrInfo *procinfo, Oid collation,
}
else if (e == newCandidate)
{
e->closer = CheckElementCloser(base, e, r, procinfo, collation);
e->closer = CheckElementCloser(base, e, r, index, procinfo, collation);
if (e->closer)
added = lappend(added, e);
}
@@ -1077,7 +1263,7 @@ AddConnections(char *base, HnswElement element, List *neighbors, int lc)
* Update connections
*/
void
HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation)
HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, FmgrInfo **procinfo, Oid *collation)
{
HnswCandidate newHc;
@@ -1103,7 +1289,7 @@ HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newE
c = lappend(c, &neighbors->items[i]);
c = lappend(c, &newHc);
SelectNeighbors(base, c, lm, procinfo, collation, &neighbors->closerSet, &newHc, &pruned, true);
SelectNeighbors(base, c, lm, index, procinfo, collation, &neighbors->closerSet, &newHc, &pruned, true);
/* Should not happen */
if (pruned == NULL)
@@ -1174,17 +1360,19 @@ PrecomputeHash(char *base, HnswElement element)
* Algorithm 1 from paper
*/
void
HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing)
HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo **procinfo, Oid *collation, int m, int efConstruction, bool existing, bool inMemory)
{
List *ep;
List *w;
int level = element->level;
int entryLevel;
Datum q = HnswGetValue(base, element);
IndexTuple qtup = HnswPtrAccess(base, element->itup);
ScanKeyData *keyData = NULL;
HnswElement skipElement = existing ? element : NULL;
/* Precompute hash */
if (index == NULL)
if (inMemory)
PrecomputeHash(base, element);
/* No neighbors if no entry point */
@@ -1192,13 +1380,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
return;
/* Get entry point and level */
ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true));
ep = list_make1(HnswEntryCandidate(base, entryPoint, q, qtup, keyData, index, procinfo, collation, true, inMemory));
entryLevel = entryPoint->level;
/* 1st phase: greedy search to insert level */
for (int lc = entryLevel; lc >= level + 1; lc--)
{
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement);
w = HnswSearchLayer(base, q, qtup, keyData, ep, 1, lc, index, procinfo, collation, m, true, skipElement, inMemory);
ep = w;
}
@@ -1217,7 +1405,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
List *lw = NIL;
ListCell *lc2;
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement);
w = HnswSearchLayer(base, q, qtup, keyData, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, inMemory);
/* Convert search candidates to candidates */
foreach(lc2, w)
@@ -1233,7 +1421,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
/* Elements being deleted or skipped can help with search */
/* but should be removed before selecting neighbors */
if (index != NULL)
if (!inMemory)
lw = RemoveElements(base, lw, skipElement);
/*
@@ -1241,7 +1429,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
* sortCandidates to true for in-memory builds to enable closer
* caching, but there does not seem to be a difference in performance.
*/
neighbors = SelectNeighbors(base, lw, lm, procinfo, collation, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false);
neighbors = SelectNeighbors(base, lw, lm, index, procinfo, collation, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false);
AddConnections(base, element, neighbors, lc);

View File

@@ -189,8 +189,8 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme
GenericXLogState *state;
int m = vacuumstate->m;
int efConstruction = vacuumstate->efConstruction;
FmgrInfo *procinfo = vacuumstate->procinfo;
Oid collation = vacuumstate->collation;
FmgrInfo **procinfo = vacuumstate->procinfo;
Oid *collation = vacuumstate->collation;
BufferAccessStrategy bas = vacuumstate->bas;
HnswNeighborTuple ntup = vacuumstate->ntup;
Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m);
@@ -205,7 +205,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme
element->heaptidsLength = 0;
/* Find neighbors for element, skipping itself */
HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true);
HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true, false);
/* Zero memory for each element */
MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE);
@@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
LockPage(index, HNSW_UPDATE_LOCK, ShareLock);
/* Load element */
HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL);
HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL);
/* Repair if needed */
if (NeedsUpdated(vacuumstate, highestPoint))
@@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
* is outdated, this can remove connections at higher levels in
* the graph until they are repaired, but this should be fine.
*/
HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL);
HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL);
if (NeedsUpdated(vacuumstate, entryPoint))
{
@@ -370,7 +370,7 @@ RepairGraph(HnswVacuumState * vacuumstate)
/* Create an element */
element = HnswInitElementFromBlock(blkno, offno);
HnswLoadElementFromTuple(element, etup, false, true);
HnswLoadElementFromTuple(element, etup, false, true, index);
elements = lappend(elements, element);
}
@@ -440,6 +440,7 @@ MarkDeleted(HnswVacuumState * vacuumstate)
BlockNumber insertPage = InvalidBlockNumber;
Relation index = vacuumstate->index;
BufferAccessStrategy bas = vacuumstate->bas;
bool useIndexTuple = HnswUseIndexTuple(index);
/*
* Wait for index scans to complete. Scans before this point may contain
@@ -521,7 +522,14 @@ MarkDeleted(HnswVacuumState * vacuumstate)
/* Overwrite element */
etup->deleted = 1;
MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data));
if (useIndexTuple)
{
IndexTuple itup = (IndexTuple) &etup->data;
MemSet(itup, 0, IndexTupleSize(itup));
}
else
MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data));
/* Overwrite neighbors */
for (int i = 0; i < ntup->count; i++)
@@ -573,8 +581,8 @@ InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkD
vacuumstate->callback_state = callback_state;
vacuumstate->efConstruction = HnswGetEfConstruction(index);
vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD);
vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
vacuumstate->collation = index->rd_indcollation[0];
HnswInitProcinfo(vacuumstate->procinfo, index);
vacuumstate->collation = index->rd_indcollation;
vacuumstate->ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE);
vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw vacuum temporary context",

View File

@@ -0,0 +1,109 @@
use strict;
use warnings FATAL => 'all';
use PostgreSQL::Test::Cluster;
use PostgreSQL::Test::Utils;
use Test::More;
my $node;
my @queries = ();
my @cs = ();
my @expected;
my $limit = 20;
my $dim = 3;
my $array_sql = join(",", ('random()') x $dim);
my $nc = 50;
sub test_recall
{
my ($min, $operator) = @_;
my $correct = 0;
my $total = 0;
my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $cs[0] ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
like($explain, qr/Index Cond/);
for my $i (0 .. $#queries)
{
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my %actual_set = map { $_ => 1 } @actual_ids;
is(scalar(@actual_ids), $limit);
my @expected_ids = split("\n", $expected[$i]);
foreach (@expected_ids)
{
if (exists($actual_set{$_}))
{
$correct++;
}
$total++;
}
}
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
$node = PostgreSQL::Test::Cluster->new('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4);");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc FROM generate_series(1, 20000) i;"
);
# Generate queries
for (1 .. 20)
{
my @r = ();
for (1 .. $dim)
{
push(@r, rand());
}
push(@queries, "[" . join(",", @r) . "]");
push(@cs, int(rand() * $nc));
}
# Get exact results
@expected = ();
for my $i (0 .. $#queries)
{
my $res = $node->safe_psql("postgres", "SELECT i FROM tst WHERE c = $cs[$i] ORDER BY v <-> '$queries[$i]' LIMIT $limit;");
push(@expected, $res);
}
# Add index
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c);");
# Test recall
test_recall(0.99, '<->');
# Test vacuum
$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 5;");
$node->safe_psql("postgres", "VACUUM tst;");
# Test columns
my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c);");
like($stderr, qr/first column must be a vector/);
($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (c, v vector_l2_ops);");
like($stderr, qr/first column must be a vector/);
($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, c, c);");
like($stderr, qr/index cannot have more than two columns/);
($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops, v vector_l2_ops);");
like($stderr, qr/column 2 cannot be a vector/);
done_testing();