diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a4682e7..3b81560 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,7 +49,7 @@ jobs: - postgres: 16 os: macos-14 - postgres: 14 - os: macos-12 + os: macos-13 steps: - uses: actions/checkout@v4 - uses: ankane/setup-postgres@v1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c2e827..008bba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,10 @@ ## 0.8.0 (unreleased) +- Added support for iterative index scans - Added `intvec` type - Added casts for arrays to `sparsevec` +- Improved cost estimation +- Improved performance of HNSW inserts and on-disk index builds - Reduced memory usage for HNSW index scans - Dropped support for Postgres 12 diff --git a/Dockerfile b/Dockerfile index eef0c8f..9364409 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -ARG PG_MAJOR=16 +ARG PG_MAJOR=17 FROM postgres:$PG_MAJOR ARG PG_MAJOR diff --git a/Makefile b/Makefile index b68263d..987b1db 100644 --- a/Makefile +++ b/Makefile @@ -66,7 +66,7 @@ dist: git archive --format zip --prefix=$(EXTENSION)-$(EXTVERSION)/ --output dist/$(EXTENSION)-$(EXTVERSION).zip master # for Docker -PG_MAJOR ?= 16 +PG_MAJOR ?= 17 .PHONY: docker diff --git a/README.md b/README.md index 2314bb1..4ea05a3 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ nmake /F Makefile.win nmake /F Makefile.win install ``` +Note: Postgres 17 is not supported yet due to an upstream issue + See the [installation notes](#installation-notes---windows) if you run into issues You can also install it with [Docker](#docker) or [conda-forge](#conda-forge). @@ -100,6 +102,8 @@ Or add a vector column to an existing table ALTER TABLE items ADD COLUMN embedding vector(3); ``` +Also supports [half-precision](#half-precision-vectors), [binary](#binary-vectors), and [sparse](#sparse-vectors) vectors + Insert vectors ```sql @@ -145,6 +149,8 @@ Supported distance functions are: - `<#>` - (negative) inner product - `<=>` - cosine distance - `<+>` - L1 distance (added in 0.7.0) +- `<~>` - Hamming distance (binary vectors, added in 0.7.0) +- `<%>` - Jaccard distance (binary vectors, added in 0.7.0) Get the nearest neighbors to a row @@ -1018,7 +1024,7 @@ l2_normalize(sparsevec) → sparsevec | Normalize with Euclidean norm | 0.7.0 If your machine has multiple Postgres installations, specify the path to [pg_config](https://www.postgresql.org/docs/current/app-pgconfig.html) with: ```sh -export PG_CONFIG=/Library/PostgreSQL/16/bin/pg_config +export PG_CONFIG=/Library/PostgreSQL/17/bin/pg_config ``` Then re-run the installation instructions (run `make clean` before `make` if needed). If `sudo` is needed for `make install`, use: @@ -1029,11 +1035,11 @@ sudo --preserve-env=PG_CONFIG make install A few common paths on Mac are: -- EDB installer - `/Library/PostgreSQL/16/bin/pg_config` -- Homebrew (arm64) - `/opt/homebrew/opt/postgresql@16/bin/pg_config` -- Homebrew (x86-64) - `/usr/local/opt/postgresql@16/bin/pg_config` +- EDB installer - `/Library/PostgreSQL/17/bin/pg_config` +- Homebrew (arm64) - `/opt/homebrew/opt/postgresql@17/bin/pg_config` +- Homebrew (x86-64) - `/usr/local/opt/postgresql@17/bin/pg_config` -Note: Replace `16` with your Postgres server version +Note: Replace `17` with your Postgres server version ### Missing Header @@ -1042,10 +1048,10 @@ If compilation fails with `fatal error: postgres.h: No such file or directory`, For Ubuntu and Debian, use: ```sh -sudo apt install postgresql-server-dev-16 +sudo apt install postgresql-server-dev-17 ``` -Note: Replace `16` with your Postgres server version +Note: Replace `17` with your Postgres server version ### Missing SDK @@ -1078,17 +1084,17 @@ If installation fails with `Access is denied`, re-run the installation instructi Get the [Docker image](https://hub.docker.com/r/pgvector/pgvector) with: ```sh -docker pull pgvector/pgvector:pg16 +docker pull pgvector/pgvector:pg17 ``` -This adds pgvector to the [Postgres image](https://hub.docker.com/_/postgres) (replace `16` with your Postgres server version, and run it the same way). +This adds pgvector to the [Postgres image](https://hub.docker.com/_/postgres) (replace `17` with your Postgres server version, and run it the same way). You can also build the image manually: ```sh git clone --branch v0.7.4 https://github.com/pgvector/pgvector.git cd pgvector -docker build --pull --build-arg PG_MAJOR=16 -t myuser/pgvector . +docker build --pull --build-arg PG_MAJOR=17 -t myuser/pgvector . ``` ### Homebrew @@ -1099,7 +1105,7 @@ With Homebrew Postgres, you can use: brew install pgvector ``` -Note: This only adds it to the `postgresql@14` formula +Note: This only adds it to the `postgresql@17` and `postgresql@14` formulas ### PGXN @@ -1114,22 +1120,22 @@ pgxn install vector Debian and Ubuntu packages are available from the [PostgreSQL APT Repository](https://wiki.postgresql.org/wiki/Apt). Follow the [setup instructions](https://wiki.postgresql.org/wiki/Apt#Quickstart) and run: ```sh -sudo apt install postgresql-16-pgvector +sudo apt install postgresql-17-pgvector ``` -Note: Replace `16` with your Postgres server version +Note: Replace `17` with your Postgres server version ### Yum RPM packages are available from the [PostgreSQL Yum Repository](https://yum.postgresql.org/). Follow the [setup instructions](https://www.postgresql.org/download/linux/redhat/) for your distribution and run: ```sh -sudo yum install pgvector_16 +sudo yum install pgvector_17 # or -sudo dnf install pgvector_16 +sudo dnf install pgvector_17 ``` -Note: Replace `16` with your Postgres server version +Note: Replace `17` with your Postgres server version ### pkg diff --git a/src/halfvec.c b/src/halfvec.c index 9cd3de6..aad320b 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -159,24 +159,6 @@ CheckStateArray(ArrayType *statearray, const char *caller) return (float8 *) ARR_DATA_PTR(statearray); } -#if PG_VERSION_NUM < 120003 -static pg_noinline void -float_overflow_error(void) -{ - ereport(ERROR, - (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), - errmsg("value out of range: overflow"))); -} - -static pg_noinline void -float_underflow_error(void) -{ - ereport(ERROR, - (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), - errmsg("value out of range: underflow"))); -} -#endif - /* * Convert textual representation to internal representation */ diff --git a/src/hnsw.c b/src/hnsw.c index a7b1e5f..5c747b5 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -12,12 +12,22 @@ #include "utils/float.h" #include "utils/guc.h" #include "utils/selfuncs.h" +#include "utils/spccache.h" #if PG_VERSION_NUM < 150000 #define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x) #endif +static const struct config_enum_entry hnsw_iterative_search_options[] = { + {"off", HNSW_ITERATIVE_SEARCH_OFF, false}, + {"relaxed_order", HNSW_ITERATIVE_SEARCH_RELAXED, false}, + {"strict_order", HNSW_ITERATIVE_SEARCH_STRICT, false}, + {NULL, 0, false} +}; + int hnsw_ef_search; +int hnsw_max_search_tuples; +int hnsw_iterative_search; int hnsw_lock_tranche_id; static relopt_kind hnsw_relopt_kind; @@ -68,6 +78,15 @@ HnswInit(void) "Valid range is 1..1000.", &hnsw_ef_search, HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); + DefineCustomEnumVariable("hnsw.iterative_search", "Sets the iterative search mode", + NULL, &hnsw_iterative_search, + HNSW_ITERATIVE_SEARCH_OFF, hnsw_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL); + + /* This is approximate and does not apply to the initial scan */ + DefineCustomIntVariable("hnsw.max_search_tuples", "Sets the max number of candidates to visit for iterative search", + "-1 means no limit", &hnsw_max_search_tuples, + -1, -1, INT_MAX, PGC_USERSET, 0, NULL, NULL, NULL); + MarkGUCPrefixReserved("hnsw"); } @@ -99,7 +118,9 @@ hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, { GenericCosts costs; int m; - int entryLevel; + double ratio; + double startupPages; + double spc_seq_page_cost; Relation index; /* Never use index without order */ @@ -115,21 +136,71 @@ hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, MemSet(&costs, 0, sizeof(costs)); + genericcostestimate(root, path, loop_count, &costs); + index = index_open(path->indexinfo->indexoid, NoLock); HnswGetMetaPageInfo(index, &m, NULL); index_close(index, NoLock); - /* Approximate entry level */ - entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m); + /* + * HNSW cost estimation follows a formula that accounts for the total + * number of tuples indexed combined with the parameters that most + * influence the duration of the index scan, namely: m - the number of + * tuples that are scanned in each step of the HNSW graph traversal + * ef_search - which influences the total number of steps taken at layer 0 + * + * The source of the vector data can impact how many steps it takes to + * converge on the set of vectors to return to the executor. Currently, we + * use a hardcoded scaling factor (HNSWScanScalingFactor) to help + * influence that, but this could later become a configurable parameter + * based on the cost estimations. + * + * The tuple estimator formula is below: + * + * numIndexTuples = entryLevel * m + layer0TuplesMax * layer0Selectivity + * + * "entryLevel * m" represents the floor of tuples we need to scan to get + * to layer 0 (L0). + * + * "layer0TuplesMax" is the estimated total number of tuples we'd scan at + * L0 if we weren't discarding already visited tuples as part of the scan. + * + * "layer0Selectivity" estimates the percentage of tuples that are scanned + * at L0, accounting for previously visited tuples, multiplied by the + * "scalingFactor" (currently hardcoded). + */ + if (path->indexinfo->tuples > 0) + { + double scalingFactor = 0.55; + int entryLevel = (int) (log(path->indexinfo->tuples) * HnswGetMl(m)); + int layer0TuplesMax = HnswGetLayerM(m, 0) * hnsw_ef_search; + double layer0Selectivity = scalingFactor * log(path->indexinfo->tuples) / (log(m) * (1 + log(hnsw_ef_search))); - /* TODO Improve estimate of visited tuples (currently underestimates) */ - /* Account for number of tuples (or entry level), m, and ef_search */ - costs.numIndexTuples = (entryLevel + 2) * m; + ratio = (entryLevel * m + layer0TuplesMax * layer0Selectivity) / path->indexinfo->tuples; - genericcostestimate(root, path, loop_count, &costs); + if (ratio > 1) + ratio = 1; + } + else + ratio = 1; - /* Use total cost since most work happens before first tuple is returned */ - *indexStartupCost = costs.indexTotalCost; + get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost); + + /* Startup cost is cost before returning the first row */ + costs.indexStartupCost = costs.indexTotalCost * ratio; + + /* Adjust cost if needed since TOAST not included in seq scan cost */ + startupPages = costs.numIndexPages * ratio; + if (startupPages > path->indexinfo->rel->pages && ratio < 0.5) + { + /* Change all page cost from random to sequential */ + costs.indexStartupCost -= startupPages * (costs.spc_random_page_cost - spc_seq_page_cost); + + /* Remove cost of extra pages */ + costs.indexStartupCost -= (startupPages - path->indexinfo->rel->pages) * spc_seq_page_cost; + } + + *indexStartupCost = costs.indexStartupCost; *indexTotalCost = costs.indexTotalCost; *indexSelectivity = costs.indexSelectivity; *indexCorrelation = costs.indexCorrelation; diff --git a/src/hnsw.h b/src/hnsw.h index 9fb650a..6b184ec 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -88,6 +88,9 @@ /* Ensure fits on page and in uint8 */ #define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - offsetof(HnswNeighborTupleData, indextids) - sizeof(ItemIdData)) / (sizeof(ItemPointerData)) / (m)) - 2, 255) +#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr) +#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr) + #define HnswGetValue(base, element) PointerGetDatum(HnswPtrAccess(base, (element)->value)) #if PG_VERSION_NUM < 140005 @@ -106,8 +109,17 @@ /* Variables */ extern int hnsw_ef_search; +extern int hnsw_iterative_search; +extern int hnsw_max_search_tuples; extern int hnsw_lock_tranche_id; +typedef enum HnswIterativeSearchMode +{ + HNSW_ITERATIVE_SEARCH_OFF, + HNSW_ITERATIVE_SEARCH_RELAXED, + HNSW_ITERATIVE_SEARCH_STRICT +} HnswIterativeSearchMode; + typedef struct HnswElementData HnswElementData; typedef struct HnswNeighborArray HnswNeighborArray; @@ -129,6 +141,7 @@ struct HnswElementData uint8 heaptidsLength; uint8 level; uint8 deleted; + uint8 version; uint32 hash; HnswNeighborsPtr neighbors; BlockNumber blkno; @@ -160,7 +173,7 @@ typedef struct HnswSearchCandidate pairingheap_node c_node; pairingheap_node w_node; HnswElementPtr element; - float distance; + double distance; } HnswSearchCandidate; /* HNSW index options */ @@ -185,8 +198,8 @@ typedef struct HnswGraph /* Allocations state */ LWLock allocatorLock; - long memoryUsed; - long memoryTotal; + Size memoryUsed; + Size memoryTotal; /* Flushed state */ LWLock flushLock; @@ -237,6 +250,18 @@ typedef struct HnswTypeInfo void (*checkValue) (Pointer v); } HnswTypeInfo; +typedef struct HnswSupport +{ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; +} HnswSupport; + +typedef struct HnswQuery +{ + Datum value; +} HnswQuery; + typedef struct HnswBuildState { /* Info */ @@ -256,9 +281,7 @@ typedef struct HnswBuildState double reltuples; /* Support functions */ - FmgrInfo *procinfo; - FmgrInfo *normprocinfo; - Oid collation; + HnswSupport support; /* Variables */ HnswGraph graphData; @@ -306,10 +329,10 @@ typedef struct HnswElementTupleData uint8 type; uint8 level; uint8 deleted; - uint8 unused; + uint8 version; ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; - uint16 unused2; + uint16 unused; Vector data; } HnswElementTupleData; @@ -318,24 +341,41 @@ typedef HnswElementTupleData * HnswElementTuple; typedef struct HnswNeighborTupleData { uint8 type; - uint8 unused; + uint8 version; uint16 count; ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborTupleData; typedef HnswNeighborTupleData * HnswNeighborTuple; +typedef union +{ + struct pointerhash_hash *pointers; + struct offsethash_hash *offsets; + struct tidhash_hash *tids; +} visited_hash; + +typedef union +{ + HnswElement element; + ItemPointerData indextid; +} HnswUnvisited; + typedef struct HnswScanOpaqueData { const HnswTypeInfo *typeInfo; bool first; List *w; + visited_hash v; + pairingheap *discarded; + HnswQuery q; + int m; + int64 tuples; + double previousDistance; MemoryContext tmpCtx; /* Support functions */ - FmgrInfo *procinfo; - FmgrInfo *normprocinfo; - Oid collation; + HnswSupport support; } HnswScanOpaqueData; typedef HnswScanOpaqueData * HnswScanOpaque; @@ -353,8 +393,7 @@ typedef struct HnswVacuumState int efConstruction; /* Support functions */ - FmgrInfo *procinfo; - Oid collation; + HnswSupport support; /* Variables */ struct tidhash_hash *deleted; @@ -370,30 +409,33 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); +void HnswInitSupport(HnswSupport * support, Relation index); Datum HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value); -bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); +bool HnswCheckNorm(HnswSupport * support, Datum value); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); -List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); -void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); -HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing); +HnswSearchCandidate *HnswEntryCandidate(char *base, HnswElement em, HnswQuery * q, Relation rel, HnswSupport * support, bool loadVec); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); +HnswNeighborArray *HnswInitNeighborArray(int lm, HnswAllocator * allocator); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); -bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); -void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); +bool HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building); +void HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance); +void HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance); +bool HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); -void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); -void HnswLoadNeighbors(HnswElement element, Relation index, int m); +void HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support); +bool HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc); void HnswInitLockTranche(void); const HnswTypeInfo *HnswGetTypeInfo(Relation index); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 498b5d9..b667478 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -366,7 +366,7 @@ AddElementInMemory(char *base, HnswGraph * graph, HnswElement element) * Update neighbors */ static void -UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswElement e, int m) +UpdateNeighborsInMemory(char *base, HnswSupport * support, HnswElement e, int m) { for (int lc = e->level; lc >= 0; lc--) { @@ -388,7 +388,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme Assert(neighborElement); LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE); - HnswUpdateConnection(base, e, hc, lm, lc, NULL, NULL, procinfo, collation); + HnswUpdateConnection(base, HnswGetNeighbors(base, neighborElement, lc), e, hc->distance, lm, NULL, NULL, support); LWLockRelease(&neighborElement->lock); } } @@ -398,7 +398,7 @@ UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswEleme * Update graph in memory */ static void -UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) +UpdateGraphInMemory(HnswSupport * support, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate) { HnswGraph *graph = buildstate->graph; char *base = buildstate->hnswarea; @@ -411,7 +411,7 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int AddElementInMemory(base, graph, element); /* Update neighbors */ - UpdateNeighborsInMemory(base, procinfo, collation, element, m); + UpdateNeighborsInMemory(base, support, element, m); /* Update entry point if needed (already have lock) */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -424,9 +424,8 @@ UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int static void InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) { - FmgrInfo *procinfo = buildstate->procinfo; - Oid collation = buildstate->collation; HnswGraph *graph = buildstate->graph; + HnswSupport *support = &buildstate->support; HnswElement entryPoint; LWLock *entryLock = &graph->entryLock; LWLock *entryWaitLock = &graph->entryWaitLock; @@ -458,10 +457,10 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, NULL, support, m, efConstruction, false); /* Update graph in memory */ - UpdateGraphInMemory(procinfo, collation, element, m, efConstruction, entryPoint, buildstate); + UpdateGraphInMemory(support, element, m, efConstruction, entryPoint, buildstate); /* Release entry lock */ LWLockRelease(entryLock); @@ -473,30 +472,19 @@ InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element) static bool InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, HnswBuildState * buildstate) { - const HnswTypeInfo *typeInfo = buildstate->typeInfo; HnswGraph *graph = buildstate->graph; HnswElement element; HnswAllocator *allocator = &buildstate->allocator; + HnswSupport *support = &buildstate->support; Size valueSize; Pointer valuePtr; LWLock *flushLock = &graph->flushLock; char *base = buildstate->hnswarea; + Datum value; - /* Detoast once for all calls */ - Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); - - /* Check value */ - if (typeInfo->checkValue != NULL) - typeInfo->checkValue(DatumGetPointer(value)); - - /* Normalize if needed */ - if (buildstate->normprocinfo != NULL) - { - if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) - return false; - - value = HnswNormValue(typeInfo, buildstate->collation, value); - } + /* Form index value */ + if (!HnswFormIndexValue(&value, values, isnull, buildstate->typeInfo, support)) + return false; /* Get datum size */ valueSize = VARSIZE_ANY(DatumGetPointer(value)); @@ -509,7 +497,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn { LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true); + return HnswInsertTupleOnDisk(index, support, value, heaptid, true); } /* @@ -541,7 +529,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn LWLockRelease(flushLock); - return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true); + return HnswInsertTupleOnDisk(index, support, value, heaptid, true); } /* Ok, we can proceed to allocate the element */ @@ -607,7 +595,7 @@ BuildCallback(Relation index, ItemPointer tid, Datum *values, * Initialize the graph */ static void -InitGraph(HnswGraph * graph, char *base, long memoryTotal) +InitGraph(HnswGraph * graph, char *base, Size memoryTotal) { /* Initialize the lock tranche if needed */ HnswInitLockTranche(); @@ -704,11 +692,9 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->indtuples = 0; /* Get support functions */ - buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - buildstate->collation = index->rd_indcollation[0]; + HnswInitSupport(&buildstate->support, index); - InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); + InitGraph(&buildstate->graphData, NULL, (Size) maintenance_work_mem * 1024L); buildstate->graph = &buildstate->graphData; buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 2dce16f..a5fac4e 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -36,7 +36,7 @@ GetInsertPage(Relation index) * Check for a free offset */ static bool -HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage) +HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage, uint8 *tupleVersion) { OffsetNumber offno; OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); @@ -98,6 +98,7 @@ HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size { *freeOffno = offno; *freeNeighborOffno = neighborOffno; + *tupleVersion = etup->version; return true; } else if (*nbuf != buf) @@ -153,6 +154,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B OffsetNumber freeOffno = InvalidOffsetNumber; OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; + uint8 tupleVersion; char *base = NULL; /* Calculate sizes */ @@ -202,7 +204,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B } /* Next, try space from a deleted element */ - if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage)) + if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage, &tupleVersion)) { if (nbuf != buf) { @@ -212,6 +214,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B npage = GenericXLogRegisterBuffer(state, nbuf, 0); } + /* Set tuple version */ + etup->version = tupleVersion; + ntup->version = tupleVersion; + break; } @@ -334,6 +340,107 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B *updatedInsertPage = newInsertPage; } +/* + * Load neighbors + */ +static HnswNeighborArray * +HnswLoadNeighbors(HnswElement element, Relation index, int m, int lm, int lc) +{ + char *base = NULL; + HnswNeighborArray *neighbors = HnswInitNeighborArray(lm, NULL); + ItemPointerData indextids[HNSW_MAX_M * 2]; + + if (!HnswLoadNeighborTids(element, indextids, index, m, lm, lc)) + return neighbors; + + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &indextids[i]; + HnswElement e; + HnswCandidate *hc; + + if (!ItemPointerIsValid(indextid)) + break; + + e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); + hc = &neighbors->items[neighbors->length++]; + HnswPtrStore(base, hc->element, e); + } + + return neighbors; +} + +/* + * Load elements for insert + */ +static void +LoadElementsForInsert(HnswNeighborArray * neighbors, HnswQuery * q, int *idx, Relation index, HnswSupport * support) +{ + char *base = NULL; + + for (int i = 0; i < neighbors->length; i++) + { + HnswCandidate *hc = &neighbors->items[i]; + HnswElement element = HnswPtrAccess(base, hc->element); + double distance; + + HnswLoadElement(element, &distance, q, index, support, true, NULL); + hc->distance = distance; + + /* Prune element if being deleted */ + if (element->heaptidsLength == 0) + { + *idx = i; + break; + } + } +} + +/* + * Get update index + */ +static int +GetUpdateIndex(HnswElement element, HnswElement newElement, float distance, int m, int lm, int lc, Relation index, HnswSupport * support, MemoryContext updateCtx) +{ + char *base = NULL; + int idx = -1; + HnswNeighborArray *neighbors; + MemoryContext oldCtx = MemoryContextSwitchTo(updateCtx); + + /* + * Get latest neighbors since they may have changed. Do not lock yet since + * selecting neighbors can take time. Could use optimistic locking to + * retry if another update occurs before getting exclusive lock. + */ + neighbors = HnswLoadNeighbors(element, index, m, lm, lc); + + /* + * Could improve performance for vacuuming by checking neighbors against + * list of elements being deleted to find index. It's important to exclude + * already deleted elements for this since they can be replaced at any + * time. + */ + + if (neighbors->length < lm) + idx = -2; + else + { + HnswQuery q; + + q.value = HnswGetValue(base, element); + + LoadElementsForInsert(neighbors, &q, &idx, index, support); + + if (idx == -1) + HnswUpdateConnection(base, neighbors, newElement, distance, lm, &idx, index, support); + } + + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(updateCtx); + + return idx; +} + /* * Check if connection already exists */ @@ -354,14 +461,94 @@ ConnectionExists(HnswElement e, HnswNeighborTuple ntup, int startIdx, int lm) return false; } +/* + * Update neighbor + */ +static void +UpdateNeighborOnDisk(HnswElement element, HnswElement newElement, int idx, int m, int lm, int lc, Relation index, bool checkExisting, bool building) +{ + Buffer buf; + Page page; + GenericXLogState *state; + HnswNeighborTuple ntup; + int startIdx; + OffsetNumber offno = element->neighborOffno; + + /* Register page */ + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + if (building) + { + state = NULL; + page = BufferGetPage(buf); + } + else + { + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + } + + /* Get tuple */ + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); + + /* Calculate index for update */ + startIdx = (element->level - lc) * m; + + /* Check for existing connection */ + if (checkExisting && ConnectionExists(newElement, ntup, startIdx, lm)) + idx = -1; + else if (idx == -2) + { + /* Find free offset if still exists */ + /* TODO Retry updating connections if not */ + for (int j = 0; j < lm; j++) + { + if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) + { + idx = startIdx + j; + break; + } + } + } + else + idx += startIdx; + + /* Make robust to issues */ + if (idx >= 0 && idx < ntup->count) + { + ItemPointer indextid = &ntup->indextids[idx]; + + /* Update neighbor on the buffer */ + ItemPointerSet(indextid, newElement->blkno, newElement->offno); + + /* Commit */ + if (building) + MarkBufferDirty(buf); + else + GenericXLogFinish(state); + } + else if (!building) + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); +} + /* * Update neighbors */ void -HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building) +HnswUpdateNeighborsOnDisk(Relation index, HnswSupport * support, HnswElement e, int m, bool checkExisting, bool building) { char *base = NULL; + /* Use separate memory context to improve performance for larger vectors */ + MemoryContext updateCtx = GenerationContextCreate(CurrentMemoryContext, + "Hnsw insert update context", +#if PG_VERSION_NUM >= 150000 + 128 * 1024, 128 * 1024, +#endif + 128 * 1024); + for (int lc = e->level; lc >= 0; lc--) { int lm = HnswGetLayerM(m, lc); @@ -370,96 +557,20 @@ HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, Hns for (int i = 0; i < neighbors->length; i++) { HnswCandidate *hc = &neighbors->items[i]; - Buffer buf; - Page page; - GenericXLogState *state; - HnswNeighborTuple ntup; - int idx = -1; - int startIdx; HnswElement neighborElement = HnswPtrAccess(base, hc->element); - OffsetNumber offno = neighborElement->neighborOffno; + int idx; - /* - * Get latest neighbors since they may have changed. Do not lock - * yet since selecting neighbors can take time. Could use - * optimistic locking to retry if another update occurs before - * getting exclusive lock. - */ - HnswLoadNeighbors(neighborElement, index, m); - - /* - * Could improve performance for vacuuming by checking neighbors - * against list of elements being deleted to find index. It's - * important to exclude already deleted elements for this since - * they can be replaced at any time. - */ - - /* Select neighbors */ - HnswUpdateConnection(NULL, e, hc, lm, lc, &idx, index, procinfo, collation); + idx = GetUpdateIndex(neighborElement, e, hc->distance, m, lm, lc, index, support, updateCtx); /* New element was not selected as a neighbor */ if (idx == -1) continue; - /* Register page */ - buf = ReadBuffer(index, neighborElement->neighborPage); - LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); - if (building) - { - state = NULL; - page = BufferGetPage(buf); - } - else - { - state = GenericXLogStart(index); - page = GenericXLogRegisterBuffer(state, buf, 0); - } - - /* Get tuple */ - ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); - - /* Calculate index for update */ - startIdx = (neighborElement->level - lc) * m; - - /* Check for existing connection */ - if (checkExisting && ConnectionExists(e, ntup, startIdx, lm)) - idx = -1; - else if (idx == -2) - { - /* Find free offset if still exists */ - /* TODO Retry updating connections if not */ - for (int j = 0; j < lm; j++) - { - if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) - { - idx = startIdx + j; - break; - } - } - } - else - idx += startIdx; - - /* Make robust to issues */ - if (idx >= 0 && idx < ntup->count) - { - ItemPointer indextid = &ntup->indextids[idx]; - - /* Update neighbor on the buffer */ - ItemPointerSet(indextid, e->blkno, e->offno); - - /* Commit */ - if (building) - MarkBufferDirty(buf); - else - GenericXLogFinish(state); - } - else if (!building) - GenericXLogAbort(state); - - UnlockReleaseBuffer(buf); + UpdateNeighborOnDisk(neighborElement, e, idx, m, lm, lc, index, checkExisting, building); } } + + MemoryContextDelete(updateCtx); } /* @@ -549,7 +660,7 @@ FindDuplicateOnDisk(Relation index, HnswElement element, bool building) * Update graph on disk */ static void -UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) +UpdateGraphOnDisk(Relation index, HnswSupport * support, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) { BlockNumber newInsertPage = InvalidBlockNumber; @@ -565,7 +676,7 @@ UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement HnswUpdateMetaPage(index, 0, NULL, newInsertPage, MAIN_FORKNUM, building); /* Update neighbors */ - HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, false, building); + HnswUpdateNeighborsOnDisk(index, support, element, m, false, building); /* Update entry point if needed */ if (entryPoint == NULL || element->level > entryPoint->level) @@ -576,14 +687,12 @@ UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement * Insert a tuple into the index */ bool -HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building) +HnswInsertTupleOnDisk(Relation index, HnswSupport * support, Datum value, ItemPointer heaptid, bool building) { HnswElement entryPoint; HnswElement element; int m; int efConstruction = HnswGetEfConstruction(index); - FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - Oid collation = index->rd_indcollation[0]; LOCKMODE lockmode = ShareLock; char *base = NULL; @@ -598,7 +707,7 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, HnswGetMetaPageInfo(index, &m, &entryPoint); /* Create an element */ - element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); + element = HnswInitElement(base, heaptid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); HnswPtrStore(base, element->value, DatumGetPointer(value)); /* Prevent concurrent inserts when likely updating entry point */ @@ -616,10 +725,10 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, } /* Find neighbors for element */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, false); /* Update graph on disk */ - UpdateGraphOnDisk(index, procinfo, collation, element, m, efConstruction, entryPoint, building); + UpdateGraphOnDisk(index, support, element, m, efConstruction, entryPoint, building); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); @@ -631,31 +740,19 @@ HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, * Insert a tuple into the index */ static void -HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid) +HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid) { Datum value; const HnswTypeInfo *typeInfo = HnswGetTypeInfo(index); - FmgrInfo *normprocinfo; - Oid collation = index->rd_indcollation[0]; + HnswSupport support; - /* Detoast once for all calls */ - value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + HnswInitSupport(&support, index); - /* Check value */ - if (typeInfo->checkValue != NULL) - typeInfo->checkValue(DatumGetPointer(value)); + /* Form index value */ + if (!HnswFormIndexValue(&value, values, isnull, typeInfo, &support)) + return; - /* Normalize if needed */ - normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - if (normprocinfo != NULL) - { - if (!HnswCheckNorm(normprocinfo, collation, value)) - return; - - value = HnswNormValue(typeInfo, collation, value); - } - - HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); + HnswInsertTupleOnDisk(index, &support, value, heaptid, false); } /* diff --git a/src/hnswscan.c b/src/hnswscan.c index 30815af..ba3dea8 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -5,39 +5,74 @@ #include "pgstat.h" #include "storage/bufmgr.h" #include "storage/lmgr.h" +#include "utils/float.h" #include "utils/memutils.h" /* * Algorithm 5 from paper */ static List * -GetScanItems(IndexScanDesc scan, Datum q) +GetScanItems(IndexScanDesc scan, Datum value) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; - FmgrInfo *procinfo = so->procinfo; - Oid collation = so->collation; + HnswSupport *support = &so->support; List *ep; List *w; int m; HnswElement entryPoint; char *base = NULL; + HnswQuery *q = &so->q; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); + q->value = value; + so->m = m; + if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, NULL, NULL, true, NULL); ep = w; } - return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); + return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, &so->v, hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF ? &so->discarded : NULL, true, &so->tuples); +} + +/* + * Resume scan at ground level with discarded candidates + */ +static List * +ResumeScanItems(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + List *ep = NIL; + char *base = NULL; + int batch_size = hnsw_ef_search; + + if (pairingheap_is_empty(so->discarded)) + return NIL; + + /* Get next batch of candidates */ + for (int i = 0; i < batch_size; i++) + { + HnswSearchCandidate *sc; + + if (pairingheap_is_empty(so->discarded)) + break; + + sc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)); + + ep = lappend(ep, sc); + } + + return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples); } /* @@ -60,13 +95,24 @@ GetScanValue(IndexScanDesc scan) Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Normalize if needed */ - if (so->normprocinfo != NULL) - value = HnswNormValue(so->typeInfo, so->collation, value); + if (so->support.normprocinfo != NULL) + value = HnswNormValue(so->typeInfo, so->support.collation, value); } return value; } +#if defined(HNSW_MEMORY) +/* + * Show memory usage + */ +static void +ShowMemoryUsage(HnswScanOpaque so) +{ + elog(INFO, "memory: %zu KB, tuples: " INT64_FORMAT, MemoryContextMemAllocated(so->tmpCtx, false) / 1024, so->tuples); +} +#endif + /* * Prepare for an index scan */ @@ -81,14 +127,19 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); so->typeInfo = HnswGetTypeInfo(index); so->first = true; + so->v.tids = NULL; + so->discarded = NULL; + + /* + * Use a lower max allocation size than default to allow scanning more + * tuples for iterative search before exceeding work_mem + */ so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw scan temporary context", - ALLOCSET_DEFAULT_SIZES); + 0, 8 * 1024, 512 * 1024); /* Set support functions */ - so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - so->collation = index->rd_indcollation[0]; + HnswInitSupport(&so->support, index); scan->opaque = so; @@ -103,7 +154,15 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + if (so->v.tids != NULL) + tidhash_reset(so->v.tids); + + if (so->discarded != NULL) + pairingheap_reset(so->discarded); + so->first = true; + so->tuples = 0; + so->previousDistance = -get_float8_infinity(); MemoryContextReset(so->tmpCtx); if (keys && scan->numberOfKeys > 0) @@ -161,26 +220,104 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) so->first = false; #if defined(HNSW_MEMORY) - elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024); + ShowMemoryUsage(so); #endif } - while (list_length(so->w) > 0) + for (;;) { char *base = NULL; - HnswSearchCandidate *hc = llast(so->w); - HnswElement element = HnswPtrAccess(base, hc->element); + HnswSearchCandidate *sc; + HnswElement element; ItemPointer heaptid; + if (list_length(so->w) == 0) + { + if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_OFF) + break; + + /* Empty index */ + if (so->discarded == NULL) + break; + + /* Reached max number of tuples */ + if (hnsw_max_search_tuples != -1 && so->tuples >= hnsw_max_search_tuples) + { + if (pairingheap_is_empty(so->discarded)) + break; + + /* Return remaining tuples */ + so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded))); + } + /* Prevent scans from consuming too much memory */ + else if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L) + { + if (pairingheap_is_empty(so->discarded)) + { + ereport(DEBUG1, + (errmsg("hnsw index scan exceeded work_mem after " INT64_FORMAT " tuples", so->tuples), + errhint("Increase work_mem to scan more tuples."))); + + break; + } + + /* Return remaining tuples */ + so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded))); + } + else + { + /* + * Locking ensures when neighbors are read, the elements they + * reference will not be deleted (and replaced) during the + * iteration. + * + * Elements loaded into memory on previous iterations may have + * been deleted (and replaced), so when reading neighbors, the + * element version must be checked. + */ + LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + so->w = ResumeScanItems(scan); + + UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + +#if defined(HNSW_MEMORY) + ShowMemoryUsage(so); +#endif + } + + if (list_length(so->w) == 0) + break; + } + + sc = llast(so->w); + element = HnswPtrAccess(base, sc->element); + /* Move to next element if no valid heap TIDs */ if (element->heaptidsLength == 0) { so->w = list_delete_last(so->w); + + /* Mark memory as free for next iteration */ + if (hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF) + { + pfree(element); + pfree(sc); + } + continue; } heaptid = &element->heaptids[--element->heaptidsLength]; + if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_STRICT) + { + if (sc->distance < so->previousDistance) + continue; + + so->previousDistance = sc->distance; + } + MemoryContextSwitchTo(oldCtx); scan->xs_heaptid = *heaptid; diff --git a/src/hnswutils.c b/src/hnswutils.c index f105ea6..008f81e 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -101,19 +101,6 @@ hash_offset(Size offset) #define SH_DEFINE #include "lib/simplehash.h" -typedef union -{ - pointerhash_hash *pointers; - offsethash_hash *offsets; - tidhash_hash *tids; -} visited_hash; - -typedef union -{ - HnswElement element; - ItemPointerData indextid; -} HnswUnvisited; - /* * Get the max number of connections in an upper layer for each element in the index */ @@ -154,6 +141,17 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Init support functions + */ +void +HnswInitSupport(HnswSupport * support, Relation index) +{ + support->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + support->collation = index->rd_indcollation[0]; + support->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); +} + /* * Normalize value */ @@ -170,9 +168,9 @@ HnswNormValue(const HnswTypeInfo * typeInfo, Oid collation, Datum value) * Check if non-zero norm */ bool -HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) +HnswCheckNorm(HnswSupport * support, Datum value) { - return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; + return DatumGetFloat8(FunctionCall1Coll(support->normprocinfo, support->collation, value)) > 0; } /* @@ -201,7 +199,7 @@ HnswInitPage(Buffer buf, Page page) /* * Allocate a neighbor array */ -static HnswNeighborArray * +HnswNeighborArray * HnswInitNeighborArray(int lm, HnswAllocator * allocator) { HnswNeighborArray *a = HnswAlloc(allocator, HNSW_NEIGHBOR_ARRAY_SIZE(lm)); @@ -257,6 +255,8 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, element->level = level; element->deleted = 0; + /* Start at one to make it easier to find issues */ + element->version = 1; HnswInitNeighbors(base, element, m, allocator); @@ -398,6 +398,33 @@ HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, Bloc UnlockReleaseBuffer(buf); } +/* + * Form index value + */ +bool +HnswFormIndexValue(Datum *out, Datum *values, bool *isnull, const HnswTypeInfo * typeInfo, HnswSupport * support) +{ + /* Detoast once for all calls */ + Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Check value */ + if (typeInfo->checkValue != NULL) + typeInfo->checkValue(DatumGetPointer(value)); + + /* Normalize if needed */ + if (support->normprocinfo != NULL) + { + if (!HnswCheckNorm(support, value)) + return false; + + value = HnswNormValue(typeInfo, support->collation, value); + } + + *out = value; + + return true; +} + /* * Set element tuple, except for neighbor info */ @@ -409,6 +436,7 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; + etup->version = element->version; for (int i = 0; i < HNSW_HEAPTIDS; i++) { if (i < element->heaptidsLength) @@ -451,69 +479,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) } ntup->count = idx; -} - -/* - * Load neighbors from page - */ -static void -LoadNeighborsFromPage(HnswElement element, Relation index, Page page, int m) -{ - char *base = NULL; - - HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); - int neighborCount = (element->level + 2) * m; - - Assert(HnswIsNeighborTuple(ntup)); - - HnswInitNeighbors(base, element, m, NULL); - - /* Ensure expected neighbors */ - if (ntup->count != neighborCount) - return; - - for (int i = 0; i < neighborCount; i++) - { - HnswElement e; - int level; - HnswCandidate *hc; - ItemPointer indextid; - HnswNeighborArray *neighbors; - - indextid = &ntup->indextids[i]; - - if (!ItemPointerIsValid(indextid)) - continue; - - e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); - - /* Calculate level based on offset */ - level = element->level - i / m; - if (level < 0) - level = 0; - - neighbors = HnswGetNeighbors(base, element, level); - hc = &neighbors->items[neighbors->length++]; - HnswPtrStore(base, hc->element, e); - } -} - -/* - * Load neighbors - */ -void -HnswLoadNeighbors(HnswElement element, Relation index, int m) -{ - Buffer buf; - Page page; - - buf = ReadBuffer(index, element->neighborPage); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - - LoadNeighborsFromPage(element, index, page, m); - - UnlockReleaseBuffer(buf); + ntup->version = e->version; } /* @@ -524,6 +490,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe { element->level = etup->level; element->deleted = etup->deleted; + element->version = etup->version; element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); element->heaptidsLength = 0; @@ -549,11 +516,20 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe } } +/* + * Calculate the distance between values + */ +static inline double +HnswGetDistance(Datum a, Datum b, HnswSupport * support) +{ + return DatumGetFloat8(FunctionCall2Coll(support->procinfo, support->collation, a, b)); +} + /* * Load an element and optionally get its distance from q */ static void -HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance, HnswElement * element) +HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance, HnswElement * element) { Buffer buf; Page page; @@ -571,16 +547,16 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datu /* Calculate distance */ if (distance != NULL) { - if (DatumGetPointer(*q) == NULL) + if (DatumGetPointer(q->value) == NULL) *distance = 0; else { - *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); + *distance = HnswGetDistance(q->value, PointerGetDatum(&etup->data), support); /* Needed for intvec cosine distance */ /* TODO Improve */ if (isnan(*distance)) - *distance = FLT_MAX; + *distance = DBL_MAX; } } @@ -600,40 +576,51 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, float *distance, Datu * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, float *maxDistance) +HnswLoadElement(HnswElement element, double *distance, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec, double *maxDistance) { - HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, procinfo, collation, loadVec, maxDistance, &element); + HnswLoadElementImpl(element->blkno, element->offno, distance, q, index, support, loadVec, maxDistance, &element); } /* * Get the distance for an element */ -static float -GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo, Oid collation) +static double +GetElementDistance(char *base, HnswElement element, HnswQuery * q, HnswSupport * support) { Datum value = HnswGetValue(base, element); - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); + return HnswGetDistance(q->value, value, support); +} + +/* + * Allocate a search candidate + */ +static HnswSearchCandidate * +HnswInitSearchCandidate(char *base, HnswElement element, double distance) +{ + HnswSearchCandidate *sc = palloc(sizeof(HnswSearchCandidate)); + + HnswPtrStore(base, sc->element, element); + sc->distance = distance; + return sc; } /* * Create a candidate for the entry point */ HnswSearchCandidate * -HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +HnswEntryCandidate(char *base, HnswElement entryPoint, HnswQuery * q, Relation index, HnswSupport * support, bool loadVec) { - HnswSearchCandidate *hc = palloc(sizeof(HnswSearchCandidate)); + bool inMemory = index == NULL; + double distance; - HnswPtrStore(base, hc->element, entryPoint); - if (index == NULL) - hc->distance = GetElementDistance(base, entryPoint, q, procinfo, collation); + if (inMemory) + distance = GetElementDistance(base, entryPoint, q, support); else - HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec, NULL); - return hc; -} + HnswLoadElement(entryPoint, &distance, q, index, support, loadVec, NULL); -#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr) -#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr) + return HnswInitSearchCandidate(base, entryPoint, distance); +} /* * Compare candidate distances @@ -650,6 +637,21 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v return 0; } +/* + * Compare discarded candidate distances + */ +static int +CompareNearestDiscardedCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (HnswGetSearchCandidateConst(w_node, a)->distance < HnswGetSearchCandidateConst(w_node, b)->distance) + return 1; + + if (HnswGetSearchCandidateConst(w_node, a)->distance > HnswGetSearchCandidateConst(w_node, b)->distance) + return -1; + + return 0; +} + /* * Compare candidate distances */ @@ -669,9 +671,9 @@ CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, * Init visited */ static inline void -InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) +InitVisited(char *base, visited_hash * v, bool inMemory, int ef, int m) { - if (index != NULL) + if (!inMemory) v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL); else if (base != NULL) v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL); @@ -683,9 +685,9 @@ InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) * Add to visited */ static inline void -AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, Relation index, bool *found) +AddToVisited(char *base, visited_hash * v, HnswElementPtr elementPtr, bool inMemory, bool *found) { - if (index != NULL) + if (!inMemory) { HnswElement element = HnswPtrAccess(base, elementPtr); ItemPointerData indextid; @@ -746,39 +748,61 @@ HnswLoadUnvisitedFromMemory(char *base, HnswElement element, HnswUnvisited * unv HnswCandidate *hc = &localNeighborhood->items[i]; bool found; - AddToVisited(base, v, hc->element, NULL, &found); + AddToVisited(base, v, hc->element, true, &found); if (!found) unvisited[(*unvisitedLength)++].element = HnswPtrAccess(base, hc->element); } } +/* + * Load neighbor index TIDs + */ +bool +HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation index, int m, int lm, int lc) +{ + Buffer buf; + Page page; + HnswNeighborTuple ntup; + int start; + + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + + /* + * Ensure the neighbor tuple has not been deleted or replaced between + * index scan iterations + */ + if (ntup->version != element->version || ntup->count != (element->level + 2) * m) + { + UnlockReleaseBuffer(buf); + return false; + } + + /* Copy to minimize lock time */ + start = (element->level - lc) * m; + memcpy(indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); + + UnlockReleaseBuffer(buf); + return true; +} + /* * Load unvisited neighbors from disk */ static void HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *unvisitedLength, visited_hash * v, Relation index, int m, int lm, int lc) { - Buffer buf; - Page page; - HnswNeighborTuple ntup; - int start; ItemPointerData indextids[HNSW_MAX_M * 2]; - buf = ReadBuffer(index, element->neighborPage); - LockBuffer(buf, BUFFER_LOCK_SHARE); - page = BufferGetPage(buf); - - ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); - start = (element->level - lc) * m; - - /* Copy to minimize lock time */ - memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData)); - - UnlockReleaseBuffer(buf); - *unvisitedLength = 0; + if (!HnswLoadNeighborTids(element, indextids, index, m, lm, lc)) + return; + for (int i = 0; i < lm; i++) { ItemPointer indextid = &indextids[i]; @@ -798,24 +822,37 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u * Algorithm 2 from paper */ List * -HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; - visited_hash v; + visited_hash vh; ListCell *lc2; HnswNeighborArray *localNeighborhood = NULL; Size neighborhoodSize = 0; int lm = HnswGetLayerM(m, lc); HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited)); int unvisitedLength; + bool inMemory = index == NULL; - InitVisited(base, &v, index, ef, m); + if (v == NULL) + { + v = &vh; + initVisited = true; + } + + if (initVisited) + { + InitVisited(base, v, inMemory, ef, m); + + if (discarded != NULL) + *discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL); + } /* Create local memory for neighborhood if needed */ - if (index == NULL) + if (inMemory) { neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(lm); localNeighborhood = palloc(neighborhoodSize); @@ -824,20 +861,26 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Add entry points to v, C, and W */ foreach(lc2, ep) { - HnswSearchCandidate *hc = (HnswSearchCandidate *) lfirst(lc2); + HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2); bool found; - AddToVisited(base, &v, hc->element, index, &found); + if (initVisited) + { + AddToVisited(base, v, sc->element, inMemory, &found); - pairingheap_add(C, &hc->c_node); - pairingheap_add(W, &hc->w_node); + if (tuples != NULL) + (*tuples)++; + } + + pairingheap_add(C, &sc->c_node); + pairingheap_add(W, &sc->w_node); /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could * affect insert performance. */ - if (CountElement(skipElement, HnswPtrAccess(base, hc->element))) + if (CountElement(skipElement, HnswPtrAccess(base, sc->element))) wlen++; } @@ -852,24 +895,27 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F cElement = HnswPtrAccess(base, c->element); - if (index == NULL) - HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize); + if (inMemory) + HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize); else - HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc); + HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc); + + if (tuples != NULL) + (*tuples) += unvisitedLength; for (int i = 0; i < unvisitedLength; i++) { HnswElement eElement; HnswSearchCandidate *e; - float eDistance; + double eDistance; bool alwaysAdd = wlen < ef; f = HnswGetSearchCandidate(w_node, pairingheap_first(W)); - if (index == NULL) + if (inMemory) { eElement = unvisited[i].element; - eDistance = GetElementDistance(base, eElement, q, procinfo, collation); + eDistance = GetElementDistance(base, eElement, q, support); } else { @@ -879,25 +925,30 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Avoid any allocations if not adding */ eElement = NULL; - HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement); + HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement); if (eElement == NULL) continue; } - if (!(eDistance < f->distance || alwaysAdd)) - continue; + if (eElement == NULL || !(eDistance < f->distance || alwaysAdd)) + { + if (discarded != NULL) + { + /* Create a new candidate */ + e = HnswInitSearchCandidate(base, eElement, eDistance); + pairingheap_add(*discarded, &e->w_node); + } - Assert(!eElement->deleted); + continue; + } /* Make robust to issues */ if (eElement->level < lc) continue; /* Create a new candidate */ - e = palloc(sizeof(HnswSearchCandidate)); - HnswPtrStore(base, e->element, eElement); - e->distance = eDistance; + e = HnswInitSearchCandidate(base, eElement, eDistance); pairingheap_add(C, &e->c_node); pairingheap_add(W, &e->w_node); @@ -912,7 +963,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* No need to decrement wlen */ if (wlen > ef) - pairingheap_remove_first(W); + { + HnswSearchCandidate *d = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W)); + + if (discarded != NULL) + pairingheap_add(*discarded, &d->w_node); + } } } } @@ -920,9 +976,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F /* Add each element of W to w */ while (!pairingheap_is_empty(W)) { - HnswSearchCandidate *hc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W)); + HnswSearchCandidate *sc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W)); - w = lappend(w, hc); + w = lappend(w, sc); } return w; @@ -976,32 +1032,22 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) return 0; } -/* - * Calculate the distance between elements - */ -static float -HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation) -{ - Datum aValue = HnswGetValue(base, a); - Datum bValue = HnswGetValue(base, b); - - return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue)); -} - /* * Check if an element is closer to q than any element from R */ static bool -CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) +CheckElementCloser(char *base, HnswCandidate * e, List *r, HnswSupport * support) { HnswElement eElement = HnswPtrAccess(base, e->element); + Datum eValue = HnswGetValue(base, eElement); ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); - float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation); + Datum riValue = HnswGetValue(base, riElement); + float distance = HnswGetDistance(eValue, riValue, support); if (distance <= e->distance) return false; @@ -1014,15 +1060,14 @@ CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, O * Algorithm 4 from paper */ static List * -SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +SelectNeighbors(char *base, List *c, int lm, HnswSupport * support, bool *closerSet, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); HnswCandidate **wd; int wdlen = 0; int wdoff = 0; - HnswNeighborArray *neighbors = HnswGetNeighbors(base, e2, lc); - bool mustCalculate = !neighbors->closerSet; + bool mustCalculate = !(*closerSet); List *added = NIL; bool removedAny = false; @@ -1049,7 +1094,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); else if (list_length(added) > 0) { /* Keep Valgrind happy for in-memory, parallel builds */ @@ -1062,7 +1107,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col */ if (e->closer) { - e->closer = CheckElementCloser(base, e, added, procinfo, collation); + e->closer = CheckElementCloser(base, e, added, support); if (!e->closer) removedAny = true; @@ -1075,7 +1120,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col */ if (removedAny) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); if (e->closer) added = lappend(added, e); } @@ -1083,7 +1128,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col } else if (e == newCandidate) { - e->closer = CheckElementCloser(base, e, r, procinfo, collation); + e->closer = CheckElementCloser(base, e, r, support); if (e->closer) added = lappend(added, e); } @@ -1099,7 +1144,7 @@ SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid col } /* Cached value can only be used in future if sorted deterministically */ - neighbors->closerSet = sortCandidates; + *closerSet = sortCandidates; /* Keep pruned connections */ while (wdoff < wdlen && list_length(r) < lm) @@ -1134,18 +1179,16 @@ AddConnections(char *base, HnswElement element, List *neighbors, int lc) * Update connections */ void -HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) +HnswUpdateConnection(char *base, HnswNeighborArray * neighbors, HnswElement newElement, float distance, int lm, int *updateIdx, Relation index, HnswSupport * support) { - HnswElement hce = HnswPtrAccess(base, hc->element); - HnswNeighborArray *currentNeighbors = HnswGetNeighbors(base, hce, lc); - HnswCandidate hc2; + HnswCandidate newHc; - HnswPtrStore(base, hc2.element, element); - hc2.distance = hc->distance; + HnswPtrStore(base, newHc.element, newElement); + newHc.distance = distance; - if (currentNeighbors->length < lm) + if (neighbors->length < lm) { - currentNeighbors->items[currentNeighbors->length++] = hc2; + neighbors->items[neighbors->length++] = newHc; /* Track update */ if (updateIdx != NULL) @@ -1154,54 +1197,26 @@ HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm else { /* Shrink connections */ + List *c = NIL; HnswCandidate *pruned = NULL; - /* Load elements on insert */ - if (index != NULL) - { - Datum q = HnswGetValue(base, hce); + /* Add candidates */ + for (int i = 0; i < neighbors->length; i++) + c = lappend(c, &neighbors->items[i]); + c = lappend(c, &newHc); - for (int i = 0; i < currentNeighbors->length; i++) - { - HnswCandidate *hc3 = ¤tNeighbors->items[i]; - HnswElement hc3Element = HnswPtrAccess(base, hc3->element); - - if (HnswPtrIsNull(base, hc3Element->value)) - HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true, NULL); - else - hc3->distance = GetElementDistance(base, hc3Element, q, procinfo, collation); - - /* Prune element if being deleted */ - if (hc3Element->heaptidsLength == 0) - { - pruned = ¤tNeighbors->items[i]; - break; - } - } - } + SelectNeighbors(base, c, lm, support, &neighbors->closerSet, &newHc, &pruned, true); + /* Should not happen */ if (pruned == NULL) - { - List *c = NIL; - - /* Add candidates */ - for (int i = 0; i < currentNeighbors->length; i++) - c = lappend(c, ¤tNeighbors->items[i]); - c = lappend(c, &hc2); - - SelectNeighbors(base, c, lm, lc, procinfo, collation, hce, &hc2, &pruned, true); - - /* Should not happen */ - if (pruned == NULL) - return; - } + return; /* Find and replace the pruned element */ - for (int i = 0; i < currentNeighbors->length; i++) + for (int i = 0; i < neighbors->length; i++) { - if (HnswPtrEqual(base, currentNeighbors->items[i].element, pruned->element)) + if (HnswPtrEqual(base, neighbors->items[i].element, pruned->element)) { - currentNeighbors->items[i] = hc2; + neighbors->items[i] = newHc; /* Track update */ if (updateIdx != NULL) @@ -1261,17 +1276,20 @@ PrecomputeHash(char *base, HnswElement element) * Algorithm 1 from paper */ void -HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) +HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, HnswSupport * support, int m, int efConstruction, bool existing) { List *ep; List *w; int level = element->level; int entryLevel; - Datum q = HnswGetValue(base, element); + HnswQuery q; HnswElement skipElement = existing ? element : NULL; + bool inMemory = index == NULL; + + q.value = HnswGetValue(base, element); /* Precompute hash */ - if (index == NULL) + if (inMemory) PrecomputeHash(base, element); /* No neighbors if no entry point */ @@ -1279,13 +1297,13 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint return; /* Get entry point and level */ - ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true)); + ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, true)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { - w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); ep = w; } @@ -1304,7 +1322,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint List *lw = NIL; ListCell *lc2; - w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL); /* Convert search candidates to candidates */ foreach(lc2, w) @@ -1320,7 +1338,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ - if (index != NULL) + if (!inMemory) lw = RemoveElements(base, lw, skipElement); /* @@ -1328,7 +1346,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. */ - neighbors = SelectNeighbors(base, lw, lm, lc, procinfo, collation, element, NULL, NULL, false); + neighbors = SelectNeighbors(base, lw, lm, support, &HnswGetNeighbors(base, element, lc)->closerSet, NULL, NULL, false); AddConnections(base, element, neighbors, lc); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index 67cc645..251d9d9 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -184,13 +184,12 @@ static void RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswElement entryPoint) { Relation index = vacuumstate->index; + HnswSupport *support = &vacuumstate->support; Buffer buf; Page page; GenericXLogState *state; int m = vacuumstate->m; int efConstruction = vacuumstate->efConstruction; - FmgrInfo *procinfo = vacuumstate->procinfo; - Oid collation = vacuumstate->collation; BufferAccessStrategy bas = vacuumstate->bas; HnswNeighborTuple ntup = vacuumstate->ntup; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); @@ -205,7 +204,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme element->heaptidsLength = 0; /* Find neighbors for element, skipping itself */ - HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true); + HnswFindElementNeighbors(base, element, entryPoint, index, support, m, efConstruction, true); /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); @@ -229,7 +228,7 @@ RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswEleme UnlockReleaseBuffer(buf); /* Update neighbors */ - HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, true, false); + HnswUpdateNeighborsOnDisk(index, support, element, m, true, false); } /* @@ -239,6 +238,7 @@ static void RepairGraphEntryPoint(HnswVacuumState * vacuumstate) { Relation index = vacuumstate->index; + HnswSupport *support = &vacuumstate->support; HnswElement highestPoint = &vacuumstate->highestPoint; HnswElement entryPoint; MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); @@ -256,7 +256,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(highestPoint, NULL, NULL, index, support, true, NULL); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -294,7 +294,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true, NULL); + HnswLoadElement(entryPoint, NULL, NULL, index, support, true, NULL); if (NeedsUpdated(vacuumstate, entryPoint)) { @@ -527,6 +527,14 @@ MarkDeleted(HnswVacuumState * vacuumstate) for (int i = 0; i < ntup->count; i++) ItemPointerSetInvalid(&ntup->indextids[i]); + /* Increment version */ + /* This is used to avoid incorrect reads for iterative scans */ + /* Reserve some bits for future use */ + etup->version++; + if (etup->version > 15) + etup->version = 1; + ntup->version = etup->version; + /* * We modified the tuples in place, no need to call * PageIndexTupleOverwrite @@ -573,13 +581,13 @@ InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkD vacuumstate->callback_state = callback_state; vacuumstate->efConstruction = HnswGetEfConstruction(index); vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); - vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); - vacuumstate->collation = index->rd_indcollation[0]; vacuumstate->ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw vacuum temporary context", ALLOCSET_DEFAULT_SIZES); + HnswInitSupport(&vacuumstate->support, index); + /* Get m from metapage */ HnswGetMetaPageInfo(index, &vacuumstate->m, NULL); diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 85a247f..54a5be5 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -228,11 +228,11 @@ BuildCallback(Relation index, ItemPointer tid, Datum *values, static inline void GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list) { - Datum value; - bool isnull; - if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL)) { + Datum value; + bool isnull; + *list = DatumGetInt32(slot_getattr(slot, 1, &isnull)); value = slot_getattr(slot, 3, &isnull); @@ -254,8 +254,8 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) IndexTuple itup = NULL; /* silence compiler warning */ int64 inserted = 0; - TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple); - TupleDesc tupdesc = RelationGetDescr(index); + TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->sortdesc, &TTSOpsMinimalTuple); + TupleDesc tupdesc = buildstate->tupdesc; pgstat_progress_update_param(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_LOAD); @@ -319,6 +319,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->typeInfo = IvfflatGetTypeInfo(index); + buildstate->tupdesc = RelationGetDescr(index); buildstate->lists = IvfflatGetLists(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; @@ -356,12 +357,12 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In errmsg("dimensions must be greater than one for this opclass"))); /* Create tuple description for sorting */ - buildstate->tupdesc = CreateTemplateTupleDesc(3); - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0); - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0); + buildstate->sortdesc = CreateTemplateTupleDesc(3); + TupleDescInitEntry(buildstate->sortdesc, (AttrNumber) 1, "list", INT4OID, -1, 0); + TupleDescInitEntry(buildstate->sortdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); + TupleDescInitEntry(buildstate->sortdesc, (AttrNumber) 3, "vector", buildstate->tupdesc->attrs[0].atttypid, -1, 0); - buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); + buildstate->slot = MakeSingleTupleTableSlot(buildstate->sortdesc, &TTSOpsVirtual); buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->typeInfo->itemSize(buildstate->dimensions)); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); @@ -633,7 +634,7 @@ IvfflatParallelScanAndSort(IvfflatSpool * ivfspool, IvfflatShared * ivfshared, S InitBuildState(&buildstate, ivfspool->heap, ivfspool->index, indexInfo); memcpy(buildstate.centers->items, ivfcenters, buildstate.centers->itemsize * buildstate.centers->maxlen); buildstate.centers->length = buildstate.centers->maxlen; - ivfspool->sortstate = InitBuildSortState(buildstate.tupdesc, sortmem, coordinate); + ivfspool->sortstate = InitBuildSortState(buildstate.sortdesc, sortmem, coordinate); buildstate.sortstate = ivfspool->sortstate; scan = table_beginscan_parallel(ivfspool->heap, ParallelTableScanFromIvfflatShared(ivfshared)); @@ -950,7 +951,7 @@ AssignTuples(IvfflatBuildState * buildstate) } /* Begin serial/leader tuplesort */ - buildstate->sortstate = InitBuildSortState(buildstate->tupdesc, maintenance_work_mem, coordinate); + buildstate->sortstate = InitBuildSortState(buildstate->sortdesc, maintenance_work_mem, coordinate); /* Add tuples to sort */ if (buildstate->heap != NULL) diff --git a/src/ivfflat.c b/src/ivfflat.c index 4e9b9a4..be29dd5 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -17,8 +17,16 @@ #endif int ivfflat_probes; +int ivfflat_iterative_search; +int ivfflat_max_probes; static relopt_kind ivfflat_relopt_kind; +static const struct config_enum_entry ivfflat_iterative_search_options[] = { + {"off", IVFFLAT_ITERATIVE_SEARCH_OFF, false}, + {"relaxed_order", IVFFLAT_ITERATIVE_SEARCH_RELAXED, false}, + {NULL, 0, false} +}; + /* * Initialize index options and variables */ @@ -33,6 +41,15 @@ IvfflatInit(void) "Valid range is 1..lists.", &ivfflat_probes, IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); + DefineCustomEnumVariable("ivfflat.iterative_search", "Sets the iterative search mode", + NULL, &ivfflat_iterative_search, + IVFFLAT_ITERATIVE_SEARCH_OFF, ivfflat_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL); + + /* If this is less than probes, probes is used */ + DefineCustomIntVariable("ivfflat.max_probes", "Sets the max number of probes for iterative search", + "-1 means no limit", &ivfflat_max_probes, + -1, -1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); + MarkGUCPrefixReserved("ivfflat"); } @@ -69,6 +86,8 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, GenericCosts costs; int lists; double ratio; + double sequentialRatio = 0.5; + double startupPages; double spc_seq_page_cost; Relation index; @@ -85,6 +104,8 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, MemSet(&costs, 0, sizeof(costs)); + genericcostestimate(root, path, loop_count, &costs); + index = index_open(path->indexinfo->indexoid, NoLock); IvfflatGetMetaPageInfo(index, &lists, NULL); index_close(index, NoLock); @@ -94,41 +115,26 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, if (ratio > 1.0) ratio = 1.0; - /* - * This gives us the subset of tuples to visit. This value is passed into - * the generic cost estimator to determine the number of pages to visit - * during the index scan. - */ - costs.numIndexTuples = path->indexinfo->tuples * ratio; - - genericcostestimate(root, path, loop_count, &costs); - get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost); + /* Change some page cost from random to sequential */ + costs.indexTotalCost -= sequentialRatio * costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + + /* Startup cost is cost before returning the first row */ + costs.indexStartupCost = costs.indexTotalCost * ratio; + /* Adjust cost if needed since TOAST not included in seq scan cost */ - if (costs.numIndexPages > path->indexinfo->rel->pages && ratio < 0.5) + startupPages = costs.numIndexPages * ratio; + if (startupPages > path->indexinfo->rel->pages && ratio < 0.5) { - /* Change all page cost from random to sequential */ - costs.indexTotalCost -= costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + /* Change rest of page cost from random to sequential */ + costs.indexStartupCost -= (1 - sequentialRatio) * startupPages * (costs.spc_random_page_cost - spc_seq_page_cost); /* Remove cost of extra pages */ - costs.indexTotalCost -= (costs.numIndexPages - path->indexinfo->rel->pages) * spc_seq_page_cost; - } - else - { - /* Change some page cost from random to sequential */ - costs.indexTotalCost -= 0.5 * costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); + costs.indexStartupCost -= (startupPages - path->indexinfo->rel->pages) * spc_seq_page_cost; } - /* - * If the list selectivity is lower than what is returned from the generic - * cost estimator, use that. - */ - if (ratio < costs.indexSelectivity) - costs.indexSelectivity = ratio; - - /* Use total cost since most work happens before first tuple is returned */ - *indexStartupCost = costs.indexTotalCost; + *indexStartupCost = costs.indexStartupCost; *indexTotalCost = costs.indexTotalCost; *indexSelectivity = costs.indexSelectivity; *indexCorrelation = costs.indexCorrelation; diff --git a/src/ivfflat.h b/src/ivfflat.h index 8518317..91753c8 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -80,6 +80,14 @@ /* Variables */ extern int ivfflat_probes; +extern int ivfflat_iterative_search; +extern int ivfflat_max_probes; + +typedef enum IvfflatIterativeSearchMode +{ + IVFFLAT_ITERATIVE_SEARCH_OFF, + IVFFLAT_ITERATIVE_SEARCH_RELAXED +} IvfflatIterativeSearchMode; typedef struct VectorArrayData { @@ -165,6 +173,7 @@ typedef struct IvfflatBuildState Relation index; IndexInfo *indexInfo; const IvfflatTypeInfo *typeInfo; + TupleDesc tupdesc; /* Settings */ int dimensions; @@ -198,7 +207,7 @@ typedef struct IvfflatBuildState /* Sorting */ Tuplesortstate *sortstate; - TupleDesc tupdesc; + TupleDesc sortdesc; TupleTableSlot *slot; /* Memory */ @@ -247,8 +256,11 @@ typedef struct IvfflatScanOpaqueData { const IvfflatTypeInfo *typeInfo; int probes; + int maxProbes; int dimensions; bool first; + Datum value; + MemoryContext tmpCtx; /* Sorting */ Tuplesortstate *sortstate; @@ -265,7 +277,9 @@ typedef struct IvfflatScanOpaqueData /* Lists */ pairingheap *listQueue; - IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */ + BlockNumber *listPages; + int listIndex; + IvfflatScanList *lists; } IvfflatScanOpaqueData; typedef IvfflatScanOpaqueData * IvfflatScanOpaque; diff --git a/src/ivfinsert.c b/src/ivfinsert.c index b748c5e..014c9be 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -98,7 +98,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R IvfflatGetMetaPageInfo(index, NULL, NULL); /* Find the insert page - sets the page and list info */ - FindInsertPage(index, values, &insertPage, &listInfo); + FindInsertPage(index, &value, &insertPage, &listInfo); Assert(BlockNumberIsValid(insertPage)); originalInsertPage = insertPage; diff --git a/src/ivfscan.c b/src/ivfscan.c index 74e3675..251f70f 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -10,10 +10,7 @@ #include "miscadmin.h" #include "pgstat.h" #include "storage/bufmgr.h" - -#ifdef IVFFLAT_MEMORY #include "utils/memutils.h" -#endif #define GetScanList(ptr) pairingheap_container(IvfflatScanList, ph_node, ptr) #define GetScanListConst(ptr) pairingheap_const_container(IvfflatScanList, ph_node, ptr) @@ -65,7 +62,7 @@ GetScanLists(IndexScanDesc scan, Datum value) /* Use procinfo from the index instead of scan key for performance */ distance = DatumGetFloat8(so->distfunc(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); - if (listCount < so->probes) + if (listCount < so->maxProbes) { IvfflatScanList *scanlist; @@ -78,7 +75,7 @@ GetScanLists(IndexScanDesc scan, Datum value) pairingheap_add(so->listQueue, &scanlist->ph_node); /* Calculate max distance */ - if (listCount == so->probes) + if (listCount == so->maxProbes) maxDistance = GetScanList(pairingheap_first(so->listQueue))->distance; } else if (distance < maxDistance) @@ -102,6 +99,11 @@ GetScanLists(IndexScanDesc scan, Datum value) UnlockReleaseBuffer(cbuf); } + + for (int i = listCount - 1; i >= 0; i--) + so->listPages[i] = GetScanList(pairingheap_remove_first(so->listQueue))->startPage; + + Assert(pairingheap_is_empty(so->listQueue)); } /* @@ -114,11 +116,14 @@ GetScanItems(IndexScanDesc scan, Datum value) TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); double tuples = 0; TupleTableSlot *slot = so->vslot; + int batchProbes = 0; + + tuplesort_reset(so->sortstate); /* Search closest probes lists */ - while (!pairingheap_is_empty(so->listQueue)) + while (so->listIndex < so->maxProbes && (++batchProbes) <= so->probes) { - BlockNumber searchPage = GetScanList(pairingheap_remove_first(so->listQueue))->startPage; + BlockNumber searchPage = so->listPages[so->listIndex++]; /* Search all entry pages for list */ while (BlockNumberIsValid(searchPage)) @@ -166,13 +171,17 @@ GetScanItems(IndexScanDesc scan, Datum value) } } - if (tuples < 100) + if (tuples < 100 && ivfflat_iterative_search == IVFFLAT_ITERATIVE_SEARCH_OFF) ereport(DEBUG1, (errmsg("index scan found few tuples"), errdetail("Index may have been created with little data."), errhint("Recreate the index and possibly decrease lists."))); tuplesort_performsort(so->sortstate); + +#if defined(IVFFLAT_MEMORY) + elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); +#endif } /* @@ -209,7 +218,13 @@ GetScanValue(IndexScanDesc scan) /* Normalize if needed */ if (so->normprocinfo != NULL) + { + MemoryContext oldCtx = MemoryContextSwitchTo(so->tmpCtx); + value = IvfflatNormValue(so->typeInfo, so->collation, value); + + MemoryContextSwitchTo(oldCtx); + } } return value; @@ -240,19 +255,40 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) int lists; int dimensions; int probes = ivfflat_probes; + int maxProbes; + MemoryContext oldCtx; scan = RelationGetIndexScan(index, nkeys, norderbys); /* Get lists and dimensions from metapage */ IvfflatGetMetaPageInfo(index, &lists, &dimensions); + if (ivfflat_iterative_search != IVFFLAT_ITERATIVE_SEARCH_OFF) + { + maxProbes = ivfflat_max_probes; + + if (maxProbes < 0) + maxProbes = lists; + else if (maxProbes < probes) + { + /* TODO Show notice */ + maxProbes = probes; + } + } + else + maxProbes = probes; + if (probes > lists) probes = lists; - so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); + if (maxProbes > lists) + maxProbes = lists; + + so = (IvfflatScanOpaque) palloc(sizeof(IvfflatScanOpaqueData)); so->typeInfo = IvfflatGetTypeInfo(index); so->first = true; so->probes = probes; + so->maxProbes = maxProbes; so->dimensions = dimensions; /* Set support functions */ @@ -260,6 +296,12 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); so->collation = index->rd_indcollation[0]; + so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Ivfflat scan temporary context", + ALLOCSET_DEFAULT_SIZES); + + oldCtx = MemoryContextSwitchTo(so->tmpCtx); + /* Create tuple description for sorting */ so->tupdesc = CreateTemplateTupleDesc(2); TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0); @@ -280,6 +322,11 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) so->bas = GetAccessStrategy(BAS_BULKREAD); so->listQueue = pairingheap_allocate(CompareLists, scan); + so->listPages = palloc(maxProbes * sizeof(BlockNumber)); + so->listIndex = 0; + so->lists = palloc(maxProbes * sizeof(IvfflatScanList)); + + MemoryContextSwitchTo(oldCtx); scan->opaque = so; @@ -294,11 +341,9 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - if (!so->first) - tuplesort_reset(so->sortstate); - so->first = true; pairingheap_reset(so->listQueue); + so->listIndex = 0; if (keys && scan->numberOfKeys > 0) memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); @@ -314,6 +359,8 @@ bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; + ItemPointer heaptid; + bool isnull; /* * Index can be used to scan backward, but Postgres doesn't support @@ -341,28 +388,23 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) IvfflatBench("GetScanLists", GetScanLists(scan, value)); IvfflatBench("GetScanItems", GetScanItems(scan, value)); so->first = false; - -#if defined(IVFFLAT_MEMORY) - elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); -#endif - - /* Clean up if we allocated a new value */ - if (value != scan->orderByData->sk_argument) - pfree(DatumGetPointer(value)); + so->value = value; } - if (tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL)) + while (!tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL)) { - bool isnull; - ItemPointer heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull)); + if (so->listIndex == so->maxProbes) + return false; - scan->xs_heaptid = *heaptid; - scan->xs_recheck = false; - scan->xs_recheckorderby = false; - return true; + IvfflatBench("GetScanItems", GetScanItems(scan, so->value)); } - return false; + heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull)); + + scan->xs_heaptid = *heaptid; + scan->xs_recheck = false; + scan->xs_recheckorderby = false; + return true; } /* @@ -373,12 +415,10 @@ ivfflatendscan(IndexScanDesc scan) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - pairingheap_free(so->listQueue); + /* Free any temporary files */ tuplesort_end(so->sortstate); - FreeAccessStrategy(so->bas); - FreeTupleDesc(so->tupdesc); - /* TODO Free vslot and mslot without freeing TupleDesc */ + MemoryContextDelete(so->tmpCtx); pfree(so); scan->opaque = NULL; diff --git a/src/ivfvacuum.c b/src/ivfvacuum.c index 57815af..1272da8 100644 --- a/src/ivfvacuum.c +++ b/src/ivfvacuum.c @@ -26,7 +26,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, Page cpage; OffsetNumber coffno; OffsetNumber cmaxoffno; - BlockNumber startPages[MaxOffsetNumber]; + BlockNumber listPages[MaxOffsetNumber]; ListInfo listInfo; cbuf = ReadBuffer(index, blkno); @@ -40,7 +40,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, { IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); - startPages[coffno - FirstOffsetNumber] = list->startPage; + listPages[coffno - FirstOffsetNumber] = list->startPage; } listInfo.blkno = blkno; @@ -50,7 +50,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { - BlockNumber searchPage = startPages[coffno - FirstOffsetNumber]; + BlockNumber searchPage = listPages[coffno - FirstOffsetNumber]; BlockNumber insertPage = InvalidBlockNumber; /* Iterate over entry pages */ diff --git a/src/vector.c b/src/vector.c index facc07e..a5b2aac 100644 --- a/src/vector.c +++ b/src/vector.c @@ -155,24 +155,6 @@ CheckStateArray(ArrayType *statearray, const char *caller) return (float8 *) ARR_DATA_PTR(statearray); } -#if PG_VERSION_NUM < 120003 -static pg_noinline void -float_overflow_error(void) -{ - ereport(ERROR, - (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), - errmsg("value out of range: overflow"))); -} - -static pg_noinline void -float_underflow_error(void) -{ - ereport(ERROR, - (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), - errmsg("value out of range: underflow"))); -} -#endif - /* * Convert textual representation to internal representation */ diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index cbda5fa..60eb011 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -99,6 +99,32 @@ SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2 4 (1 row) +DROP TABLE t; +-- iterative +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +SET hnsw.iterative_search = strict_order; +SET hnsw.ef_search = 1; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +SET hnsw.iterative_search = relaxed_order; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +RESET hnsw.iterative_search; +RESET hnsw.ef_search; DROP TABLE t; -- unlogged CREATE UNLOGGED TABLE t (val vector(3)); @@ -139,4 +165,21 @@ SET hnsw.ef_search = 0; ERROR: 0 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) SET hnsw.ef_search = 1001; ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) +SHOW hnsw.iterative_search; + hnsw.iterative_search +----------------------- + off +(1 row) + +SET hnsw.iterative_search = on; +ERROR: invalid value for parameter "hnsw.iterative_search": "on" +HINT: Available values: off, relaxed_order, strict_order. +SHOW hnsw.max_search_tuples; + hnsw.max_search_tuples +------------------------ + -1 +(1 row) + +SET hnsw.max_search_tuples = -2; +ERROR: -2 is outside the valid range for parameter "hnsw.max_search_tuples" (-1 .. 2147483647) DROP TABLE t; diff --git a/test/expected/ivfflat_vector.out b/test/expected/ivfflat_vector.out index 84871b4..8a80ea3 100644 --- a/test/expected/ivfflat_vector.out +++ b/test/expected/ivfflat_vector.out @@ -81,6 +81,44 @@ SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2 3 (1 row) +DROP TABLE t; +-- iterative +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 3); +SET ivfflat.iterative_search = relaxed_order; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +SET ivfflat.max_probes = 0; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] +(1 row) + +SET ivfflat.max_probes = 1; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] +(1 row) + +SET ivfflat.max_probes = 2; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] +(2 rows) + +RESET ivfflat.iterative_search; +RESET ivfflat.max_probes; DROP TABLE t; -- unlogged CREATE UNLOGGED TABLE t (val vector(3)); @@ -109,4 +147,27 @@ SHOW ivfflat.probes; 1 (1 row) +SET ivfflat.probes = 0; +ERROR: 0 is outside the valid range for parameter "ivfflat.probes" (1 .. 32768) +SET ivfflat.probes = 32769; +ERROR: 32769 is outside the valid range for parameter "ivfflat.probes" (1 .. 32768) +SHOW ivfflat.iterative_search; + ivfflat.iterative_search +-------------------------- + off +(1 row) + +SET ivfflat.iterative_search = on; +ERROR: invalid value for parameter "ivfflat.iterative_search": "on" +HINT: Available values: off, relaxed_order. +SHOW ivfflat.max_probes; + ivfflat.max_probes +-------------------- + -1 +(1 row) + +SET ivfflat.max_probes = -2; +ERROR: -2 is outside the valid range for parameter "ivfflat.max_probes" (-1 .. 32768) +SET ivfflat.max_probes = 32769; +ERROR: 32769 is outside the valid range for parameter "ivfflat.max_probes" (-1 .. 32768) DROP TABLE t; diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index b7896cf..ba8d4aa 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -57,6 +57,23 @@ SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <+> (SELECT NULL::vector)) t2 DROP TABLE t; +-- iterative + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +SET hnsw.iterative_search = strict_order; +SET hnsw.ef_search = 1; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +SET hnsw.iterative_search = relaxed_order; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +RESET hnsw.iterative_search; +RESET hnsw.ef_search; +DROP TABLE t; + -- unlogged CREATE UNLOGGED TABLE t (val vector(3)); @@ -81,4 +98,12 @@ SHOW hnsw.ef_search; SET hnsw.ef_search = 0; SET hnsw.ef_search = 1001; +SHOW hnsw.iterative_search; + +SET hnsw.iterative_search = on; + +SHOW hnsw.max_search_tuples; + +SET hnsw.max_search_tuples = -2; + DROP TABLE t; diff --git a/test/sql/ivfflat_vector.sql b/test/sql/ivfflat_vector.sql index 32759e3..9e20060 100644 --- a/test/sql/ivfflat_vector.sql +++ b/test/sql/ivfflat_vector.sql @@ -44,6 +44,28 @@ SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2 DROP TABLE t; +-- iterative + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 3); + +SET ivfflat.iterative_search = relaxed_order; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +SET ivfflat.max_probes = 0; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +SET ivfflat.max_probes = 1; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +SET ivfflat.max_probes = 2; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +RESET ivfflat.iterative_search; +RESET ivfflat.max_probes; +DROP TABLE t; + -- unlogged CREATE UNLOGGED TABLE t (val vector(3)); @@ -62,4 +84,16 @@ CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 32769); SHOW ivfflat.probes; +SET ivfflat.probes = 0; +SET ivfflat.probes = 32769; + +SHOW ivfflat.iterative_search; + +SET ivfflat.iterative_search = on; + +SHOW ivfflat.max_probes; + +SET ivfflat.max_probes = -2; +SET ivfflat.max_probes = 32769; + DROP TABLE t; diff --git a/test/t/002_ivfflat_vacuum.pl b/test/t/002_ivfflat_vacuum.pl index d4cfeaf..a7f1d9e 100644 --- a/test/t/002_ivfflat_vacuum.pl +++ b/test/t/002_ivfflat_vacuum.pl @@ -6,13 +6,7 @@ use Test::More; my $dim = 3; -my @r = (); -for (1 .. $dim) -{ - my $v = int(rand(1000)) + 1; - push(@r, "i % $v"); -} -my $array_sql = join(", ", @r); +my $array_sql = join(",", ('random()') x $dim); # Initialize node my $node = PostgreSQL::Test::Cluster->new('node'); @@ -23,19 +17,20 @@ $node->start; $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node->safe_psql("postgres", - "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); # Get size my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); +# Store values +$node->safe_psql("postgres", "CREATE TABLE tmp AS SELECT * FROM tst;"); + # Delete all, vacuum, and insert same data $node->safe_psql("postgres", "DELETE FROM tst;"); $node->safe_psql("postgres", "VACUUM tst;"); -$node->safe_psql("postgres", - "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" -); +$node->safe_psql("postgres", "INSERT INTO tst SELECT * FROM tmp;"); # Check size my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); diff --git a/test/t/009_ivfflat_filtering.pl b/test/t/009_ivfflat_filtering.pl index efe0866..72b2c53 100644 --- a/test/t/009_ivfflat_filtering.pl +++ b/test/t/009_ivfflat_filtering.pl @@ -94,8 +94,7 @@ like($explain, qr/Seq Scan/); $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query'; )); -# TODO Do not use index -like($explain, qr/Index Scan using idx/); +like($explain, qr/Seq Scan/); # Test attribute index $node->safe_psql("postgres", "CREATE INDEX attribute_idx ON tst (c);"); @@ -110,7 +109,6 @@ $node->safe_psql("postgres", "CREATE INDEX partial_idx ON tst USING ivfflat (v v $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; )); -# TODO Use partial index -like($explain, qr/Index Scan using idx/); +like($explain, qr/Index Scan using partial_idx/); done_testing(); diff --git a/test/t/017_hnsw_filtering.pl b/test/t/017_hnsw_filtering.pl index 249b32d..afa2a1c 100644 --- a/test/t/017_hnsw_filtering.pl +++ b/test/t/017_hnsw_filtering.pl @@ -18,9 +18,13 @@ $node->start; # Create table and index $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4, t text);"); +$node->safe_psql("postgres", "CREATE TABLE cat (i int4 PRIMARY KEY, t text, b boolean);"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc, 'test ' || i FROM generate_series(1, 10000) i;" ); +$node->safe_psql("postgres", + "INSERT INTO cat SELECT i, 'cat ' || i, i % 5 = 0 FROM generate_series(1, $nc) i;" +); $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); $node->safe_psql("postgres", "ANALYZE tst;"); @@ -37,8 +41,7 @@ my $c = int(rand() * $nc); my $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; )); -# TODO Do not use index -like($explain, qr/Index Scan using idx/); +like($explain, qr/Seq Scan/); # Test attribute filtering with few rows removed $explain = $node->safe_psql("postgres", qq( @@ -56,8 +59,7 @@ like($explain, qr/Index Scan using idx/); $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c < 1 ORDER BY v <-> '$query' LIMIT $limit; )); -# TODO Do not use index -like($explain, qr/Index Scan using idx/); +like($explain, qr/Seq Scan/); # Test attribute filtering with few rows removed like $explain = $node->safe_psql("postgres", qq( @@ -96,13 +98,25 @@ $explain = $node->safe_psql("postgres", qq( )); like($explain, qr/Seq Scan/); +# Test join +$explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT cat.t FROM cat INNER JOIN tst ON cat.i = tst.c ORDER BY v <-> '$query' LIMIT $limit; +)); +like($explain, qr/Index Scan using idx/); + +# Test join with attribute filtering +$explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT cat.t FROM cat INNER JOIN tst ON cat.i = tst.c WHERE cat.b = 't' ORDER BY v <-> '$query' LIMIT $limit; +)); +like($explain, qr/Index Scan using idx/); + # Test attribute index $node->safe_psql("postgres", "CREATE INDEX attribute_idx ON tst (c);"); $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; )); -# TODO Use attribute index -like($explain, qr/Index Scan using idx/); +# Use attribute index +like($explain, qr/Bitmap Index Scan on attribute_idx/); # Test partial index $node->safe_psql("postgres", "CREATE INDEX partial_idx ON tst USING hnsw (v vector_l2_ops) WHERE (c = $c);"); diff --git a/test/t/039_hnsw_cost.pl b/test/t/039_hnsw_cost.pl new file mode 100644 index 0000000..97ea5e7 --- /dev/null +++ b/test/t/039_hnsw_cost.pl @@ -0,0 +1,60 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my @dims = (384, 1536); +my $limit = 10; + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); + +for my $dim (@dims) +{ + my $array_sql = join(",", ('random()') x $dim); + + # Create table and index + $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 2000) i;" + ); + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); + $node->safe_psql("postgres", "ANALYZE tst;"); + + # Generate query + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + my $query = "[" . join(",", @r) . "]"; + + my $explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v <-> '$query' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx/); + + # 3x the rows are needed for distance filters + # since the planner uses DEFAULT_INEQ_SEL for the selectivity (should be 1) + # Recreate index for performance + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(2001, 6000) i;" + ); + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); + $node->safe_psql("postgres", "ANALYZE tst;"); + + $explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx/); + + $node->safe_psql("postgres", "DROP TABLE tst;"); +} + +done_testing(); diff --git a/test/t/040_ivfflat_cost.pl b/test/t/040_ivfflat_cost.pl new file mode 100644 index 0000000..1c311a3 --- /dev/null +++ b/test/t/040_ivfflat_cost.pl @@ -0,0 +1,50 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my @dims = (384, 1536); +my $limit = 10; + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); + +for my $dim (@dims) +{ + my $array_sql = join(",", ('random()') x $dim); + + # Create table and index + $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 5000) i;" + ); + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 5);"); + $node->safe_psql("postgres", "ANALYZE tst;"); + + # Generate query + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + my $query = "[" . join(",", @r) . "]"; + + my $explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v <-> '$query' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx/); + + $explain = $node->safe_psql("postgres", qq( + EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx/); + + $node->safe_psql("postgres", "DROP TABLE tst;"); +} + +done_testing(); diff --git a/test/t/041_ivfflat_iterative_search.pl b/test/t/041_ivfflat_iterative_search.pl new file mode 100644 index 0000000..6e0f721 --- /dev/null +++ b/test/t/041_ivfflat_iterative_search.pl @@ -0,0 +1,54 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 10; + SET ivfflat.iterative_search = relaxed_order; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +foreach ((30, 50, 70)) +{ + my $max_probes = $_; + my $expected = $max_probes / 10; + my $sum = 0; + + for my $i (1 .. 20) + { + $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 10; + SET ivfflat.iterative_search = relaxed_order; + SET ivfflat.max_probes = $max_probes; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t; + )); + $sum += $count; + } + + my $avg = $sum / 20; + cmp_ok($avg, '>', $expected - 2); + cmp_ok($avg, '<', $expected + 2); +} + +done_testing(); diff --git a/test/t/042_ivfflat_iterative_search_recall.pl b/test/t/042_ivfflat_iterative_search_recall.pl new file mode 100644 index 0000000..b6844b5 --- /dev/null +++ b/test/t/042_ivfflat_iterative_search_recall.pl @@ -0,0 +1,125 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my @cs = (100, 1000); + +sub test_recall +{ + my ($c, $probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SET ivfflat.iterative_search = relaxed_order; + EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SET ivfflat.iterative_search = relaxed_order; + SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, "$operator $c"); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v $opclass);"); + + foreach (@cs) + { + my $c = $_; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + WITH top AS ( + SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + if ($c == 100) + { + test_recall($c, 1, 0.57, $operator); + test_recall($c, 10, 0.98, $operator); + } + else + { + if ($operator eq "<->") + { + test_recall($c, 1, 0.80, $operator); + } + else + { + test_recall($c, 1, 0.88, $operator); + } + } + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing(); diff --git a/test/t/043_hnsw_iterative_search.pl b/test/t/043_hnsw_iterative_search.pl new file mode 100644 index 0000000..8e1aa1e --- /dev/null +++ b/test/t/043_hnsw_iterative_search.pl @@ -0,0 +1,67 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + SET max_parallel_maintenance_workers = 2; + CREATE INDEX ON tst USING hnsw (v vector_l2_ops) +)); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = relaxed_order; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +foreach ((30000, 50000, 70000)) +{ + my $max_tuples = $_; + my $expected = $max_tuples / 10000; + my $sum = 0; + + for my $i (1 .. 20) + { + $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = relaxed_order; + SET hnsw.max_search_tuples = $max_tuples; + SET work_mem = '8MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t; + )); + $sum += $count; + } + + my $avg = $sum / 20; + cmp_ok($avg, '>', $expected - 2); + cmp_ok($avg, '<', $expected + 2); +} + +my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.iterative_search = relaxed_order; + SET client_min_messages = debug1; + SET work_mem = '2MB'; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +like($stderr, qr/hnsw index scan exceeded work_mem after \d+ tuples/); + +done_testing(); diff --git a/test/t/044_hnsw_iterative_search_recall.pl b/test/t/044_hnsw_iterative_search_recall.pl new file mode 100644 index 0000000..911fccb --- /dev/null +++ b/test/t/044_hnsw_iterative_search_recall.pl @@ -0,0 +1,118 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my @cs = (50, 500); + +sub test_recall +{ + my ($c, $ef_search, $min, $operator, $mode) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.iterative_search = $mode; + EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SET hnsw.iterative_search = $mode; + SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, "$operator $mode $c"); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 50000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); +} + +# Check each index type +my @operators = ("<->", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + $node->safe_psql("postgres", qq( + SET maintenance_work_mem = '128MB'; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + + foreach (@cs) + { + my $c = $_; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + WITH top AS ( + SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + test_recall($c, 40, 0.99, $operator, "strict_order"); + test_recall($c, 40, 0.99, $operator, "relaxed_order"); + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();