commit 6df7fa05b243860ed1a3553730246921dc5cb917 Author: Andrew Kane Date: Tue Apr 20 14:04:28 2021 -0700 First commit diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..0cf6f7f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,6 @@ +root = true + +[*.{c,h}] +indent_style = tab +indent_size = tab +tab_width = 4 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..13d1116 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,20 @@ +name: build +on: [push, pull_request] +jobs: + build: + if: "!contains(github.event.head_commit.message, '[skip ci]')" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + postgres: [13, 12, 11, 10, 9.6] + steps: + - uses: actions/checkout@v2 + - uses: ankane/setup-postgres@v1 + with: + postgres-version: ${{ matrix.postgres }} + - run: sudo apt-get install postgresql-server-dev-${{ matrix.postgres }} libipc-run-perl + - run: make + - run: sudo make install + - run: make installcheck + - run: make prove_installcheck diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d161065 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +results +tmp_check +regression.* +*.o +*.so diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..fdbb576 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +## 0.1.0 (unreleased) + +- First release diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d4236ce --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + +Portions Copyright (c) 1994, The Regents of the University of California + +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose, without fee, and without a written agreement +is hereby granted, provided that the above copyright notice and this +paragraph and the following two paragraphs appear in all copies. 
+ +IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR +DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING +LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO +PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d149cc5 --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +EXTENSION = vector +DATA = vector--0.1.0.sql +MODULE_big = vector +OBJS = ivfbuild.o ivfflat.o ivfinsert.o ivfkmeans.o ivfscan.o ivfutils.o ivfvacuum.o vector.o + +TESTS = $(wildcard sql/*.sql) + +REGRESS = btree cast functions ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_unlogged vector + +PG_CONFIG ?= pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) + +prove_installcheck: + rm -rf $(CURDIR)/tmp_check + cd $(srcdir) && TESTDIR='$(CURDIR)' PATH="$(bindir):$$PATH" PGPORT='6$(DEF_PGPORT)' PG_REGRESS='$(top_builddir)/src/test/regress/pg_regress' $(PROVE) $(PG_PROVE_FLAGS) $(PROVE_FLAGS) $(if $(PROVE_TESTS),$(PROVE_TESTS),t/*.pl) diff --git a/README.md b/README.md new file mode 100644 index 0000000..614de0d --- /dev/null +++ b/README.md @@ -0,0 +1,189 @@ +# pgvector + +Open-source vector similarity search for Postgres + +```sql +CREATE TABLE table (column vector(3)); +CREATE INDEX ON table USING ivfflat (column); +SELECT * FROM table ORDER BY column <-> '[1,2,3]' LIMIT 5; +``` + +Supports L2 distance, inner product, and cosine distance + +[![Build 
Status](https://github.com/ankane/pgvector/workflows/build/badge.svg?branch=master)](https://github.com/ankane/pgvector/actions) + +## Installation + +Compile and install the extension (supports Postgres 9.6+) + +```sh +git clone https://github.com/ankane/pgvector.git +cd pgvector +make +make install # may need sudo +``` + +Then load it in databases where you want to use it + +```sql +CREATE EXTENSION vector; +``` + +## Getting Started + +Create a vector column with 3 dimensions (replace `table` and `column` with non-reserved names) + +```sql +CREATE TABLE table (column vector(3)); +``` + +Insert values + +```sql +INSERT INTO table VALUES ('[1,2,3]'), ('[4,5,6]'); +``` + +Get the nearest neighbor by L2 distance + +```sql +SELECT * FROM table ORDER BY column <-> '[3,1,2]' LIMIT 1; +``` + +Also supports inner product (`<#>`) and cosine distance (`<=>`) + +Note: `<#>` returns the negative inner product since Postgres only supports `ASC` order index scans on operators + +## Indexing + +Speed up queries with an approximate index. Add an index for each distance function you want to use. + +L2 distance + +```sql +CREATE INDEX ON table USING ivfflat (column); +``` + +Inner product + +```sql +CREATE INDEX ON table USING ivfflat (column vector_ip_ops); +``` + +Cosine distance + +```sql +CREATE INDEX ON table USING ivfflat (column vector_cosine_ops); +``` + +Indexes should be created after the table has data for optimal clustering. Also, unlike typical indexes which only affect performance, you may see different results for queries after adding an approximate index. + +### Index Options + +Specify the number of inverted lists (100 by default) + +```sql +CREATE INDEX ON table USING ivfflat (column) WITH (lists = 100); +``` + +### Query Options + +Specify the number of probes (1 by default) + +```sql +SET ivfflat.probes = 1; +``` + +A higher value improves recall at the cost of speed. 
+ +Use `SET LOCAL` inside a transaction to set it for a single query + +```sql +BEGIN; +SET LOCAL ivfflat.probes = 1; +SELECT ... +COMMIT; +``` + +## Reference + +### Vector Type + +Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a float, and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 1024 dimensions. + +### Vector Operators + +Operator | Description +--- | --- +\+ | element-wise addition +\- | element-wise subtraction +<-> | Euclidean distance +<#> | negative inner product +<=> | cosine distance + +### Vector Functions + +Function | Description +--- | --- +cosine_distance(vector, vector) | cosine distance +inner_product(vector, vector) | inner product +l2_distance(vector, vector) | Euclidean distance +vector_dims(vector) | number of dimensions +vector_norm(vector) | Euclidean norm + +## Thanks + +Thanks to: + +- [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131) +- [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss) +- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) +- [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) +- [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) + +## History + +View the [changelog](https://github.com/ankane/pgvector/blob/master/CHANGELOG.md) + +## Contributing + +Everyone is encouraged to help improve this project. 
Here are a few ways you can help: + +- [Report bugs](https://github.com/ankane/pgvector/issues) +- Fix bugs and [submit pull requests](https://github.com/ankane/pgvector/pulls) +- Write, clarify, or fix documentation +- Suggest or add new features + +To get started with development: + +```sh +git clone https://github.com/ankane/pgvector.git +cd pgvector +make +make install +``` + +To run all tests: + +```sh +make installcheck # regression tests +make prove_installcheck # TAP tests +``` + +To run single tests: + +```sh +make installcheck REGRESS=vector # regression test +make prove_installcheck PROVE_TESTS=t/001_wal.pl # TAP test +``` + +Directories + +- `expected` - expected output for regression tests +- `sql` - regression tests +- `t` - TAP tests + +Resources for contributors + +- [Extension Building Infrastructure](https://www.postgresql.org/docs/current/extend-pgxs.html) +- [Index Access Method Interface Definition](https://www.postgresql.org/docs/current/indexam.html) +- [Generic WAL Records](https://www.postgresql.org/docs/13/generic-wal.html) diff --git a/expected/btree.out b/expected/btree.out new file mode 100644 index 0000000..4a215ce --- /dev/null +++ b/expected/btree.out @@ -0,0 +1,19 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t (val); +SELECT * FROM t WHERE val = '[1,2,3]'; + val +--------- + [1,2,3] +(1 row) + +SELECT * FROM t ORDER BY val LIMIT 1; + val +--------- + [0,0,0] +(1 row) + +DROP TABLE t; diff --git a/expected/cast.out b/expected/cast.out new file mode 100644 index 0000000..1cef001 --- /dev/null +++ b/expected/cast.out @@ -0,0 +1,30 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SELECT ARRAY[1,2,3]::vector; + array +--------- + [1,2,3] +(1 row) + +SELECT ARRAY[1,2,3]::float4[]::vector; + array +--------- + [1,2,3] +(1 row) 
+ +SELECT ARRAY[1,2,3]::float8[]::vector; + array +--------- + [1,2,3] +(1 row) + +SELECT '{NULL}'::real[]::vector; +ERROR: array must not containing NULLs +SELECT '{NaN}'::real[]::vector; +ERROR: NaN not allowed in vector +SELECT '{Infinity}'::real[]::vector; +ERROR: infinite value not allowed in vector +SELECT '{-Infinity}'::real[]::vector; +ERROR: infinite value not allowed in vector +SELECT '{}'::real[]::vector; +ERROR: vector must have at least 1 dimension diff --git a/expected/functions.out b/expected/functions.out new file mode 100644 index 0000000..930a60e --- /dev/null +++ b/expected/functions.out @@ -0,0 +1,56 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SELECT '[1,2,3]'::vector + '[4,5,6]'; + ?column? +---------- + [5,7,9] +(1 row) + +SELECT '[1,2,3]'::vector - '[4,5,6]'; + ?column? +------------ + [-3,-3,-3] +(1 row) + +SELECT vector_dims('[1,2,3]'); + vector_dims +------------- + 3 +(1 row) + +SELECT round(vector_norm('[1,1]')::numeric, 5); + round +--------- + 1.41421 +(1 row) + +SELECT round(l2_distance('[1,2]', '[0,0]')::numeric, 5); + round +--------- + 2.23607 +(1 row) + +SELECT l2_distance('[1,2]', '[3]'); +ERROR: different vector dimensions 2 and 1 +SELECT inner_product('[1,2]', '[3,4]'); + inner_product +--------------- + 11 +(1 row) + +SELECT inner_product('[1,2]', '[3]'); +ERROR: different vector dimensions 2 and 1 +SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5); + round +--------- + 0.00000 +(1 row) + +SELECT cosine_distance('[1,2]', '[0,0]'); + cosine_distance +----------------- + NaN +(1 row) + +SELECT cosine_distance('[1,2]', '[3]'); +ERROR: different vector dimensions 2 and 1 diff --git a/expected/ivfflat_cosine.out b/expected/ivfflat_cosine.out new file mode 100644 index 0000000..92662ff --- /dev/null +++ b/expected/ivfflat_cosine.out @@ -0,0 +1,21 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); 
+INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); + val +----- +(0 rows) + +DROP TABLE t; diff --git a/expected/ivfflat_ip.out b/expected/ivfflat_ip.out new file mode 100644 index 0000000..af95308 --- /dev/null +++ b/expected/ivfflat_ip.out @@ -0,0 +1,22 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); + val +----- +(0 rows) + +DROP TABLE t; diff --git a/expected/ivfflat_l2.out b/expected/ivfflat_l2.out new file mode 100644 index 0000000..506cf48 --- /dev/null +++ b/expected/ivfflat_l2.out @@ -0,0 +1,22 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); + val +----- +(0 rows) + +DROP TABLE t; diff --git a/expected/ivfflat_unlogged.out b/expected/ivfflat_unlogged.out new file mode 100644 index 0000000..59b6e49 --- /dev/null +++ b/expected/ivfflat_unlogged.out @@ -0,0 +1,15 @@ +SET client_min_messages = 
warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +DROP TABLE t; diff --git a/expected/vector.out b/expected/vector.out new file mode 100644 index 0000000..11226f8 --- /dev/null +++ b/expected/vector.out @@ -0,0 +1,55 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SELECT '[1,2,3]'::vector; + vector +--------- + [1,2,3] +(1 row) + +SELECT '[-1,2,3]'::vector; + vector +---------- + [-1,2,3] +(1 row) + +SELECT '[hello,1]'::vector; +ERROR: invalid input syntax for type vector: "hello" +LINE 1: SELECT '[hello,1]'::vector; + ^ +SELECT '[NaN,1]'::vector; +ERROR: NaN not allowed in vector +LINE 1: SELECT '[NaN,1]'::vector; + ^ +SELECT '[Infinity,1]'::vector; +ERROR: infinite value not allowed in vector +LINE 1: SELECT '[Infinity,1]'::vector; + ^ +SELECT '[-Infinity,1]'::vector; +ERROR: infinite value not allowed in vector +LINE 1: SELECT '[-Infinity,1]'::vector; + ^ +SELECT '[1,2,3'::vector; +ERROR: malformed vector literal +LINE 1: SELECT '[1,2,3'::vector; + ^ +DETAIL: Unexpected end of input. +SELECT '[1,2,3]9'::vector; +ERROR: malformed vector literal +LINE 1: SELECT '[1,2,3]9'::vector; + ^ +DETAIL: Junk after closing right brace. +SELECT '1,2,3'::vector; +ERROR: malformed vector literal: "1,2,3" +LINE 1: SELECT '1,2,3'::vector; + ^ +DETAIL: Vector contents must start with "[". 
+SELECT '[]'::vector; +ERROR: vector must have at least 1 dimension +LINE 1: SELECT '[]'::vector; + ^ +SELECT '[1,]'::vector; +ERROR: invalid input syntax for type vector: "]" +LINE 1: SELECT '[1,]'::vector; + ^ +SELECT '[1,2,3]'::vector(2); +ERROR: expected 2 dimensions, not 3 diff --git a/ivfbuild.c b/ivfbuild.c new file mode 100644 index 0000000..228db3c --- /dev/null +++ b/ivfbuild.c @@ -0,0 +1,503 @@ +#include "postgres.h" + +#include + +#include "catalog/index.h" +#include "ivfflat.h" +#include "miscadmin.h" +#include "storage/bufmgr.h" + +#if PG_VERSION_NUM >= 120000 +#include "access/tableam.h" +#endif + +#if PG_VERSION_NUM >= 110000 +#include "catalog/pg_operator_d.h" +#include "catalog/pg_type_d.h" +#else +#include "catalog/pg_operator.h" +#include "catalog/pg_type.h" +#endif + +#if PG_VERSION_NUM >= 130000 +#define CALLBACK_ITEM_POINTER ItemPointer tid +#else +#define CALLBACK_ITEM_POINTER HeapTuple hup +#endif + +/* + * Callback for sampling + */ +static void +SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + IvfflatBuildState *buildstate = (IvfflatBuildState *) state; + VectorArray samples = buildstate->samples; + int targsamples = samples->maxlen; + Datum value = values[0]; + + /* Skip nulls */ + if (isnull[0]) + return; + + /* Normalize the value */ + if (buildstate->normprocinfo != NULL) + { + if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec)) + return; + } + + if (samples->length < targsamples) + { + VectorArraySet(samples, samples->length, DatumGetVector(value)); + samples->length++; + } + else + { + if (buildstate->rowstoskip < 0) + buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples); + + if (buildstate->rowstoskip <= 0) + { + int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate)); + + Assert(k >= 0 && k < targsamples); + VectorArraySet(samples, k, 
DatumGetVector(value)); + } + + buildstate->rowstoskip -= 1; + } +} + +/* + * Sample rows with same logic as ANALYZE + */ +static void +SampleRows(IvfflatBuildState * buildstate) +{ + int targsamples = buildstate->samples->maxlen; + BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap); + + buildstate->rowstoskip = -1; + + BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random()); + + reservoir_init_selection_state(&buildstate->rstate, targsamples); + while (BlockSampler_HasMore(&buildstate->bs)) + { + BlockNumber targblock = BlockSampler_Next(&buildstate->bs); + +#if PG_VERSION_NUM >= 120000 + table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); +#elif PG_VERSION_NUM >= 110000 + IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL); +#else + IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, targblock, 1, SampleCallback, (void *) buildstate); +#endif + } +} + +/* + * Callback for table_index_build_scan + */ +static void +BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + IvfflatBuildState *buildstate = (IvfflatBuildState *) state; + double distance; + double minDistance = DBL_MAX; + int closestCenter = -1; + VectorArray centers = buildstate->centers; + TupleTableSlot *slot = buildstate->slot; + Datum value = values[0]; + int i; + +#if PG_VERSION_NUM < 130000 + ItemPointer tid = &hup->t_self; +#endif + + if (isnull[0]) + return; + + /* Normalize if needed */ + if (buildstate->normprocinfo != NULL) + { + if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec)) + return; + } + + /* Find the list that minimizes the distance */ + for (i = 0; i < centers->length; i++) + { 
+ distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i)))); + + if (distance < minDistance) + { + minDistance = distance; + closestCenter = i; + } + } + + /* Create a virtual tuple */ + ExecClearTuple(slot); + slot->tts_values[0] = Int32GetDatum(closestCenter); + slot->tts_isnull[0] = false; + slot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid)); + slot->tts_isnull[1] = false; + slot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid)); + slot->tts_isnull[2] = false; + slot->tts_values[3] = value; + slot->tts_isnull[3] = false; + ExecStoreVirtualTuple(slot); + + /* + * Add tuple to sort + * + * tuplesort_puttupleslot comment: Input data is always copied; the caller + * need not save it. + */ + tuplesort_puttupleslot(buildstate->sortstate, slot); +} + +/* + * Get index tuple from sort state + */ +static inline void +GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list) +{ + Datum value; + bool isnull; + int tupblk; + int tupoff; + +#if PG_VERSION_NUM >= 100000 + if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL)) +#else + if (tuplesort_gettupleslot(sortstate, true, slot, NULL)) +#endif + { + *list = DatumGetInt32(slot_getattr(slot, 1, &isnull)); + tupblk = DatumGetInt32(slot_getattr(slot, 2, &isnull)); + tupoff = DatumGetInt32(slot_getattr(slot, 3, &isnull)); + value = slot_getattr(slot, 4, &isnull); + + /* Form the index tuple */ + *itup = index_form_tuple(tupdesc, &value, &isnull); + ItemPointerSet(&(*itup)->t_tid, tupblk, tupoff); + } + else + *list = -1; +} + +/* + * Create initial entry pages + */ +static void +InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + int list; + IndexTuple itup = NULL; /* silence compiler warning */ + BlockNumber startPage = InvalidBlockNumber; + 
BlockNumber insertPage = InvalidBlockNumber; + Size itemsz; + int i; + +#if PG_VERSION_NUM >= 120000 + TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple); +#else + TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc); +#endif + TupleDesc tupdesc = RelationGetDescr(index); + + GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list); + + for (i = 0; i < buildstate->centers->length; i++) + { + buf = IvfflatNewBuffer(index, forkNum); + IvfflatInitPage(index, &buf, &page, &state); + + startPage = BufferGetBlockNumber(buf); + + /* Get all tuples for list */ + while (list == i) + { + /* Check for free space */ + itemsz = MAXALIGN(IndexTupleSize(itup)); + if (PageGetFreeSpace(page) < itemsz) + IvfflatAppendPage(index, &buf, &page, &state, forkNum); + + /* Add the item */ + if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + buildstate->indtuples += 1; + + GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list); + } + + insertPage = BufferGetBlockNumber(buf); + + IvfflatCommitBuffer(buf, state); + + /* Set the start and insert pages */ + IvfflatUpdateList(index, state, buildstate->listInfo[i], insertPage, startPage, forkNum); + } +} + +/* + * Initialize the build state + */ +static void +InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) +{ + buildstate->heap = heap; + buildstate->index = index; + buildstate->indexInfo = indexInfo; + + buildstate->lists = IvfflatGetLists(index); + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + + /* Require column to have dimensions to be indexed */ + if (buildstate->dimensions < 0) + elog(ERROR, "column does not have dimensions"); + + buildstate->reltuples = 0; + buildstate->indtuples = 0; + + /* Get support functions */ + buildstate->procinfo = 
index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); + buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); + buildstate->collation = index->rd_indcollation[0]; + + /* Create tuple description for sorting */ +#if PG_VERSION_NUM >= 120000 + buildstate->tupdesc = CreateTemplateTupleDesc(4); +#else + buildstate->tupdesc = CreateTemplateTupleDesc(4, false); +#endif + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0); + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0); + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0); +#if PG_VERSION_NUM >= 110000 + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0); +#else + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0); +#endif + +#if PG_VERSION_NUM >= 120000 + buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); +#else + buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc); +#endif + + buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions); + buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); + + /* Reuse for each tuple */ + buildstate->normvec = InitVector(buildstate->dimensions); +} + +/* + * Free resources + */ +static void +FreeBuildState(IvfflatBuildState * buildstate) +{ + pfree(buildstate->centers); + pfree(buildstate->listInfo); + pfree(buildstate->normvec); +} + +/* + * Compute centers + */ +static void +ComputeCenters(IvfflatBuildState * buildstate) +{ + int numSamples; + + /* Target 50 samples per list, with at least 10000 samples */ + /* The number of samples has a large effect on index build time */ + numSamples = buildstate->lists * 50; + if (numSamples < 10000) + numSamples = 10000; + + /* Sample samples */ + buildstate->samples = VectorArrayInit(numSamples, 
buildstate->dimensions); + if (buildstate->heap != NULL) + SampleRows(buildstate); + + /* Calculate centers */ + IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers); + + /* Free samples before we allocate more memory */ + pfree(buildstate->samples); +} + +/* + * Create the metapage + */ +static void +CreateMetaPage(Relation index, int dimensions, int lists, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + IvfflatMetaPage metap; + + buf = IvfflatNewBuffer(index, forkNum); + IvfflatInitPage(index, &buf, &page, &state); + + /* Set metapage data */ + metap = IvfflatPageGetMeta(page); + metap->magicNumber = IVFFLAT_MAGIC_NUMBER; + metap->version = IVFFLAT_VERSION; + metap->dimensions = dimensions; + metap->lists = lists; + ((PageHeader) page)->pd_lower = + ((char *) metap + sizeof(IvfflatMetaPageData)) - (char *) page; + + IvfflatCommitBuffer(buf, state); +} + +/* + * Create list pages + */ +static void +CreateListPages(Relation index, VectorArray centers, int dimensions, + int lists, ForkNumber forkNum, ListInfo * *listInfo) +{ + int i; + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + Size itemsz; + IvfflatList list; + + itemsz = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions)); + list = palloc(itemsz); + + buf = IvfflatNewBuffer(index, forkNum); + IvfflatInitPage(index, &buf, &page, &state); + + for (i = 0; i < lists; i++) + { + /* Load list */ + list->startPage = InvalidBlockNumber; + list->insertPage = InvalidBlockNumber; + memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions)); + + /* Ensure free space */ + if (PageGetFreeSpace(page) < itemsz) + IvfflatAppendPage(index, &buf, &page, &state, forkNum); + + /* Add the item */ + offno = PageAddItem(page, (Item) list, itemsz, InvalidOffsetNumber, false, false); + if (offno == InvalidOffsetNumber) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Save location info */ + 
(*listInfo)[i].blkno = BufferGetBlockNumber(buf); + (*listInfo)[i].offno = offno; + } + + IvfflatCommitBuffer(buf, state); + + pfree(list); +} + +/* + * Create entry pages + */ +static void +CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum) +{ + AttrNumber attNums[] = {1}; + Oid sortOperators[] = {Float8LessOperator}; + Oid sortCollations[] = {InvalidOid}; + bool nullsFirstFlags[] = {false}; + +#if PG_VERSION_NUM >= 110000 + buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, NULL, false); +#else + buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, false); +#endif + + /* Add tuples to sort */ + if (buildstate->heap != NULL) + { +#if PG_VERSION_NUM >= 120000 + buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); +#elif PG_VERSION_NUM >= 110000 + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); +#else + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate); +#endif + } + + /* Sort and insert */ + tuplesort_performsort(buildstate->sortstate); + InsertTuples(buildstate->index, buildstate, forkNum); + tuplesort_end(buildstate->sortstate); +} + +/* + * Build the index + */ +static void +BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, + IvfflatBuildState * buildstate, ForkNumber forkNum) +{ + InitBuildState(buildstate, heap, index, indexInfo); + + ComputeCenters(buildstate); + + /* Create pages */ + CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum); + CreateListPages(index, buildstate->centers, buildstate->dimensions, 
buildstate->lists, forkNum, &buildstate->listInfo); + CreateEntryPages(buildstate, forkNum); + + FreeBuildState(buildstate); +} + +/* + * Build the index for a logged table + */ +IndexBuildResult * +ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo) +{ + IndexBuildResult *result; + IvfflatBuildState buildstate; + + BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); + + result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); + result->heap_tuples = buildstate.reltuples; + result->index_tuples = buildstate.indtuples; + + return result; +} + +/* + * Build the index for an unlogged table + */ +void +ivfflatbuildempty(Relation index) +{ + IndexInfo *indexInfo = BuildIndexInfo(index); + IvfflatBuildState buildstate; + + BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); +} diff --git a/ivfflat.c b/ivfflat.c new file mode 100644 index 0000000..5073705 --- /dev/null +++ b/ivfflat.c @@ -0,0 +1,168 @@ +#include "postgres.h" + +#include "access/amapi.h" +#include "commands/vacuum.h" +#include "ivfflat.h" +#include "utils/guc.h" +#include "utils/selfuncs.h" + +static relopt_kind ivfflat_relopt_kind; + +/* + * Initialize index options and variables + */ +void +_PG_init(void) +{ + ivfflat_relopt_kind = add_reloption_kind(); + add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists", + IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + + DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes", + "Valid range is 1..lists.", &ivfflat_probes, + 1, 1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); +} + +/* + * Estimate the cost of an index scan + * + * TODO Improve cost estimation + */ +static void +ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, + Cost *indexStartupCost, Cost *indexTotalCost, + Selectivity *indexSelectivity, double *indexCorrelation +#if PG_VERSION_NUM >= 100000 + ,double *indexPages +#endif +) 
+{ + GenericCosts costs; + +#if PG_VERSION_NUM < 120000 + List *qinfos; + + qinfos = deconstruct_indexquals(path); +#endif + + MemSet(&costs, 0, sizeof(costs)); + +#if PG_VERSION_NUM >= 120000 + genericcostestimate(root, path, loop_count, &costs); +#else + genericcostestimate(root, path, loop_count, qinfos, &costs); +#endif + + *indexStartupCost = costs.indexStartupCost; + *indexTotalCost = costs.indexTotalCost; + *indexSelectivity = costs.indexSelectivity; + *indexCorrelation = costs.indexCorrelation; +#if PG_VERSION_NUM >= 100000 + *indexPages = costs.numIndexPages; +#endif +} + +/* + * Parse and validate the reloptions + */ +static bytea * +ivfflatoptions(Datum reloptions, bool validate) +{ + static const relopt_parse_elt tab[] = { + {"lists", RELOPT_TYPE_INT, offsetof(IvfflatOptions, lists)}, + }; + +#if PG_VERSION_NUM >= 130000 + return (bytea *) build_reloptions(reloptions, validate, + ivfflat_relopt_kind, + sizeof(IvfflatOptions), + tab, lengthof(tab)); +#else + relopt_value *options; + int numoptions; + IvfflatOptions *rdopts; + + options = parseRelOptions(reloptions, validate, ivfflat_relopt_kind, &numoptions); + rdopts = allocateReloptStruct(sizeof(IvfflatOptions), options, numoptions); + fillRelOptions((void *) rdopts, sizeof(IvfflatOptions), options, numoptions, + validate, tab, lengthof(tab)); + + return (bytea *) rdopts; +#endif +} + +/* + * Validate catalog entries for the specified operator class + */ +static bool +ivfflatvalidate(Oid opclassoid) +{ + return true; +} + +PG_FUNCTION_INFO_V1(ivfflathandler); +Datum +ivfflathandler(PG_FUNCTION_ARGS) +{ + IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); + + amroutine->amstrategies = 0; + amroutine->amsupport = 4; +#if PG_VERSION_NUM >= 130000 + amroutine->amoptsprocnum = 0; +#endif + amroutine->amcanorder = false; + amroutine->amcanorderbyop = true; + amroutine->amcanbackward = false; /* can change direction mid-scan */ + amroutine->amcanunique = false; + amroutine->amcanmulticol = false; + 
amroutine->amoptionalkey = true; + amroutine->amsearcharray = false; + amroutine->amsearchnulls = false; + amroutine->amstorage = false; + amroutine->amclusterable = false; + amroutine->ampredlocks = false; +#if PG_VERSION_NUM >= 100000 + amroutine->amcanparallel = false; +#endif +#if PG_VERSION_NUM >= 110000 + amroutine->amcaninclude = false; +#endif +#if PG_VERSION_NUM >= 130000 + amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ + amroutine->amparallelvacuumoptions = VACUUM_OPTION_NO_PARALLEL; /* TODO support parallel */ +#endif + amroutine->amkeytype = InvalidOid; + + amroutine->ambuild = ivfflatbuild; + amroutine->ambuildempty = ivfflatbuildempty; + amroutine->aminsert = ivfflatinsert; + amroutine->ambulkdelete = ivfflatbulkdelete; + amroutine->amvacuumcleanup = ivfflatvacuumcleanup; + amroutine->amcanreturn = NULL; + amroutine->amcostestimate = ivfflatcostestimate; + amroutine->amoptions = ivfflatoptions; + amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ +#if PG_VERSION_NUM >= 120000 + amroutine->ambuildphasename = NULL; +#endif + amroutine->amvalidate = ivfflatvalidate; + amroutine->ambeginscan = ivfflatbeginscan; + amroutine->amrescan = ivfflatrescan; + amroutine->amgettuple = ivfflatgettuple; + amroutine->amgetbitmap = NULL; + amroutine->amendscan = ivfflatendscan; + amroutine->ammarkpos = NULL; + amroutine->amrestrpos = NULL; +#if PG_VERSION_NUM >= 100000 + amroutine->amestimateparallelscan = NULL; + amroutine->aminitparallelscan = NULL; + amroutine->amparallelrescan = NULL; +#endif + + PG_RETURN_POINTER(amroutine); +} diff --git a/ivfflat.h b/ivfflat.h new file mode 100644 index 0000000..8d28b81 --- /dev/null +++ b/ivfflat.h @@ -0,0 +1,193 @@ +#ifndef IVFFLAT_H +#define IVFFLAT_H + +#include "postgres.h" + +#include "access/generic_xlog.h" +#include "access/reloptions.h" +#include "nodes/execnodes.h" +#include "utils/sampling.h" +#include "utils/tuplesort.h" +#include "vector.h" + +/* Support functions */ 
+#define IVFFLAT_DISTANCE_PROC 1 /* required; fetched with index_getprocinfo for inserts and scans */
+#define IVFFLAT_NORM_PROC 2 /* optional; fetched with IvfflatOptionalProcInfo, NULL when the opclass has none */
+#define IVFFLAT_KMEANS_DISTANCE_PROC 3 /* required; distance used while choosing and refining centers */
+#define IVFFLAT_KMEANS_NORM_PROC 4 /* optional; norm applied to centers during k-means */
+
+#define IVFFLAT_VERSION 1
+#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
+#define IVFFLAT_PAGE_ID 0xFF84 /* written to IvfflatPageOpaqueData.page_id by IvfflatInitPage */
+
+/* Preserved page numbers */
+#define IVFFLAT_METAPAGE_BLKNO 0
+#define IVFFLAT_HEAD_BLKNO 1 /* first list page */
+
+#define IVFFLAT_DEFAULT_LISTS 100 /* returned by IvfflatGetLists when the index has no reloptions */
+#define IVFFLAT_MAX_LISTS 32768 /* upper bound of both the "lists" reloption and ivfflat.probes */
+
+#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
+
+#define IvfflatPageGetOpaque(page) ((IvfflatPageOpaque) PageGetSpecialPointer(page))
+#define IvfflatPageGetMeta(page) ((IvfflatMetaPageData *) PageGetContents(page))
+
+#if PG_VERSION_NUM < 100000
+#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber
+#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber
+#endif
+
+/* Variables. NOTE(review): this is a definition, not an extern declaration, so every .c file including this header emits the symbol and linking relies on common-symbol merging (-fcommon); making it extern needs one definition in a .c file */
+int ivfflat_probes;
+
+typedef struct VectorArrayData
+{
+	int length; /* number of items currently stored */
+	int maxlen; /* capacity the array was allocated with (see VectorArrayInit) */
+	int dim; /* dimensions of every item */
+	Vector items[FLEXIBLE_ARRAY_MEMBER]; /* items are VECTOR_SIZE(dim) bytes apart: use VectorArrayGet/VectorArraySet, not items[i] */
+} VectorArrayData;
+
+typedef VectorArrayData * VectorArray;
+
+typedef struct ListInfo
+{
+	BlockNumber blkno; /* list page holding the list tuple */
+	OffsetNumber offno; /* offset of the list tuple on that page */
+} ListInfo;
+
+/* IVFFlat index options */
+typedef struct IvfflatOptions
+{
+	int32 vl_len_; /* varlena header (do not touch directly!) */
+	int lists; /* number of lists */
+} IvfflatOptions;
+
+typedef struct IvfflatBuildState
+{
+	/* Info */
+	Relation heap;
+	Relation index;
+	IndexInfo *indexInfo;
+
+	/* Settings */
+	int dimensions;
+	int lists;
+
+	/* Statistics */
+	double indtuples;
+	double reltuples;
+
+	/* Support functions */
+	FmgrInfo *procinfo;
+	FmgrInfo *normprocinfo;
+	Oid collation;
+
+	/* Variables */
+	VectorArray samples;
+	VectorArray centers;
+	ListInfo *listInfo;
+	Vector *normvec;
+
+	/* Sampling */
+	BlockSamplerData bs;
+	ReservoirStateData rstate;
+	int rowstoskip;
+
+	/* Sorting */
+	Tuplesortstate *sortstate;
+	TupleDesc tupdesc;
+	TupleTableSlot *slot;
+} IvfflatBuildState;
+
+typedef struct IvfflatMetaPageData
+{
+	uint32 magicNumber;
+	uint32 version;
+	uint16 dimensions;
+	uint16 lists;
+} IvfflatMetaPageData;
+
+typedef IvfflatMetaPageData * IvfflatMetaPage;
+
+typedef struct IvfflatPageOpaqueData
+{
+	BlockNumber nextblkno; /* next page in the chain; InvalidBlockNumber on a freshly initialized page */
+	uint16 unused;
+	uint16 page_id; /* for identification of IVFFlat indexes */
+} IvfflatPageOpaqueData;
+
+typedef IvfflatPageOpaqueData * IvfflatPageOpaque;
+
+typedef struct IvfflatListData
+{
+	BlockNumber startPage; /* first entry page read when the list is scanned */
+	BlockNumber insertPage; /* page where an insert starts looking for free space */
+	Vector center; /* list center; variable length, so it must stay the last member */
+} IvfflatListData;
+
+typedef IvfflatListData * IvfflatList;
+
+typedef struct IvfflatScanList
+{
+	BlockNumber startPage; /* copied from the list tuple */
+	double distance; /* distance between the list center and the query value */
+} IvfflatScanList;
+
+typedef struct IvfflatScanOpaqueData
+{
+	int probes; /* set from ivfflat_probes on rescan, capped at the number of lists found */
+	bool first; /* true until the first gettuple after a (re)scan has collected and sorted the candidates */
+	Buffer buf; /* pin on the index page of the tuple last returned, or InvalidBuffer */
+
+	/* Sorting */
+	Tuplesortstate *sortstate;
+	TupleDesc tupdesc;
+	TupleTableSlot *slot;
+	bool isnull;
+
+	/* Support functions */
+	FmgrInfo *procinfo;
+	FmgrInfo *normprocinfo;
+	Oid collation;
+
+	IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
+} IvfflatScanOpaqueData;
+
+typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
+
+#define VECTOR_ARRAY_SIZE(_length, _dim) (offsetof(VectorArrayData, items) + (_length) * VECTOR_SIZE(_dim)) /* bytes needed for _length items; arguments parenthesized so expressions expand safely */
+#define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char *) (_arr) + offsetof(VectorArrayData, items) + (_offset) * VECTOR_SIZE((_arr)->dim)) /* address of item _offset; evaluates _arr twice */
+#define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset))
+#define VectorArraySet(_arr, _offset, _val) (memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), (_val), VECTOR_SIZE((_arr)->dim))) /* copies one item; does not touch length */
+
+/* Methods */
+void _PG_init(void);
+VectorArray VectorArrayInit(int maxlen, int dimensions);
+void PrintVectorArray(char *msg, VectorArray arr);
+void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
+FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
+bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
+int IvfflatGetLists(Relation index);
+void IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum);
+void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);
+void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum);
+Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
+void IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
+
+/* Index access methods */
+IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo);
+void ivfflatbuildempty(Relation index);
+bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
+#if PG_VERSION_NUM >= 100000
+	,IndexInfo *indexInfo
+#endif
+);
+IndexBulkDeleteResult *ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
+IndexBulkDeleteResult *ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
+IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys);
+void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
+bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir);
+void ivfflatendscan(IndexScanDesc scan);
+
+#endif
diff --git a/ivfinsert.c b/ivfinsert.c
new file mode 100644
index 0000000..fa10f0f
--- /dev/null
+++ b/ivfinsert.c
@@ -0,0 +1,163 @@
+#include "postgres.h"
+
+#include <float.h>
+
+#include "ivfflat.h"
+#include "storage/bufmgr.h"
+
+/*
+ * Find the list whose center is closest to values[0]; reports that list's insertPage via *insertPage and its (page, offset) location via *listInfo
+ */
+static void
+FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * listInfo)
+{
+	Buffer cbuf;
+	Page cpage;
+	IvfflatList list;
+	double distance;
+	double minDistance = DBL_MAX; /* from <float.h>; any real distance replaces it */
+	BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
+	FmgrInfo *procinfo;
+	Oid collation;
+	OffsetNumber offno;
+	OffsetNumber maxoffno;
+
+	procinfo = index_getprocinfo(rel, 1, IVFFLAT_DISTANCE_PROC);
+	collation = rel->rd_indcollation[0];
+
+	/* Search all list pages */
+	while (BlockNumberIsValid(nextblkno))
+	{
+		cbuf = ReadBuffer(rel, nextblkno);
+		LockBuffer(cbuf, BUFFER_LOCK_SHARE);
+		cpage = BufferGetPage(cbuf);
+		maxoffno = PageGetMaxOffsetNumber(cpage);
+
+		for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
+		{
+			list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
+			distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center)));
+
+			if (distance < minDistance)
+			{
+				*insertPage = list->insertPage;
+				listInfo->blkno = nextblkno;
+				listInfo->offno = offno;
+				minDistance = distance;
+			}
+		}
+
+		nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; /* read the link before the page is unlocked */
+
+		UnlockReleaseBuffer(cbuf);
+	}
+}
+
+/*
+ * Prepare to insert an index tuple: read and exclusively lock insertPage, start a generic WAL record and register the buffer in it
+ */
+static void
+LoadInsertPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, BlockNumber insertPage)
+{
+	*buf = ReadBuffer(index, insertPage);
+	LockBuffer(*buf, BUFFER_LOCK_EXCLUSIVE);
+	*state = GenericXLogStart(index);
+	*page = GenericXLogRegisterBuffer(*state, *buf, 0);
+}
+
+/*
+ * Add itup to the list closest to values[0], walking (and if needed extending) that list's page chain until a page has room; heapRel is not used
+ */
+static void
+InsertTuple(Relation rel, IndexTuple itup, Relation heapRel, Datum *values)
+{
+	Buffer buf;
+	Page page;
+	GenericXLogState *state;
+	Size itemsz;
+	BlockNumber insertPage = InvalidBlockNumber;
+	ListInfo listInfo;
+	bool newPage = false;
+
+	/* Find the insert page - sets the page and list info */
+	FindInsertPage(rel, values, &insertPage, &listInfo);
+	Assert(BlockNumberIsValid(insertPage));
+
+	itemsz = MAXALIGN(IndexTupleSize(itup));
+	Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)));
+
+	LoadInsertPage(rel, &buf, &page, &state, insertPage);
+
+	/* Find a page to insert the item */
+	while (PageGetFreeSpace(page) < itemsz)
+	{
+		insertPage = IvfflatPageGetOpaque(page)->nextblkno;
+
+		if (BlockNumberIsValid(insertPage))
+		{
+			/* Move to next page */
+			GenericXLogAbort(state); /* nothing was changed on the full page */
+			UnlockReleaseBuffer(buf);
+
+			LoadInsertPage(rel, &buf, &page, &state, insertPage);
+		}
+		else
+		{
+			/* Add a new page */
+			IvfflatAppendPage(rel, &buf, &page, &state, MAIN_FORKNUM);
+
+			insertPage = BufferGetBlockNumber(buf);
+			newPage = true;
+		}
+	}
+
+	/* Add to next offset */
+	if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
+		elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(rel));
+
+	IvfflatCommitBuffer(buf, state);
+
+	/* Update the insert page; state was already finished by IvfflatCommitBuffer, and IvfflatUpdateList starts its own record rather than using the value passed */
+	if (newPage)
+		IvfflatUpdateList(rel, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
+}
+
+/*
+ * Insert a tuple into the index (aminsert): NULLs and values that fail normalization are skipped; always returns false
+ */
+bool
+ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
+			  Relation heap, IndexUniqueCheck checkUnique
+#if PG_VERSION_NUM >= 100000
+			  ,IndexInfo *indexInfo
+#endif
+)
+{
+	IndexTuple itup;
+	Datum value;
+	FmgrInfo *normprocinfo;
+
+	if (isnull[0])
+		return false;
+
+	value = values[0];
+
+	/* Normalize if needed */
+	normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
+	if (normprocinfo != NULL)
+	{
+		if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL))
+			return false;
+	}
+
+	itup = index_form_tuple(RelationGetDescr(index), &value, isnull);
+	itup->t_tid = *heap_tid;
+	InsertTuple(index, itup, heap, &value);
+	pfree(itup);
+
+	/* Clean up if we allocated a new value */
+	if (value != values[0])
+		pfree(DatumGetPointer(value));
+
+	return false;
+}
diff --git a/ivfkmeans.c b/ivfkmeans.c
new file mode 100644
index 0000000..61dd917
--- /dev/null
+++ b/ivfkmeans.c
@@ -0,0 +1,479 @@
+#include "postgres.h"
+
+#include <float.h>
+
+#include "ivfflat.h"
+#include "miscadmin.h"
+
+/*
+ * Initialize with kmeans++
+ *
+ * https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
+ */
+static void
+InitCenters(Relation index, VectorArray samples, VectorArray centers, double *lowerBound)
+{
+	FmgrInfo *procinfo;
+	Oid collation;
+	int i;
+	int j;
+	double distance;
+	double sum;
+	double choice;
+	Vector *vec;
+	double *weight = palloc(samples->length * sizeof(double)); /* per sample: squared distance to its nearest chosen center so far */
+	int numCenters = centers->maxlen;
+	int numSamples = samples->length;
+
+	procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
+	collation = index->rd_indcollation[0];
+
+	/* Choose an initial center uniformly at random */
+	VectorArraySet(centers, 0, VectorArrayGet(samples, random() % samples->length));
+	centers->length++;
+
+	for (j = 0; j < numSamples; j++)
+		weight[j] = DBL_MAX; /* from <float.h>; replaced by the first real distance */
+
+	for (i = 0; i < numCenters; i++)
+	{
+		CHECK_FOR_INTERRUPTS();
+
+		sum = 0.0;
+
+		for (j = 0; j < numSamples; j++)
+		{
+			vec = VectorArrayGet(samples, j);
+
+			/* Only need to compute distance for new center */
+			/* TODO Use triangle inequality to reduce distance calculations */
+			distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
+
+			/* Set lower bound */
+			lowerBound[j * numCenters + i] = distance;
+
+			/* Use distance squared for weighted probability distribution */
+			distance *= distance;
+
+			if (distance < weight[j])
+				weight[j] = distance;
+
+			sum += weight[j];
+		}
+ + /* Only compute lower bound on last iteration */ + if (i + 1 == numCenters) + break; + + /* Choose new center using weighted probability distribution. */ + choice = sum * (((double) random()) / MAX_RANDOM_VALUE); + for (j = 0; j < numSamples - 1; j++) + { + choice -= weight[j]; + if (choice <= 0) + break; + } + + VectorArraySet(centers, i + 1, VectorArrayGet(samples, j)); + centers->length++; + } + + pfree(weight); +} + +/* + * Apply norm to vector + */ +static inline void +ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec) +{ + int i; + double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec))); + + /* TODO Handle zero norm */ + if (norm > 0) + { + for (i = 0; i < vec->dim; i++) + vec->x[i] /= norm; + } +} + +/* + * Compare vectors + */ +static int +CompareVectors(const void *a, const void *b) +{ + return vector_cmp_internal((Vector *) a, (Vector *) b); +} + +/* + * Quick approach if we have little data + */ +static void +QuickCenters(Relation index, VectorArray samples, VectorArray centers) +{ + int i; + int j; + Vector *vec; + int dimensions = centers->dim; + Oid collation = index->rd_indcollation[0]; + FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); + + /* Copy existing vectors while avoiding duplicates */ + qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors); + for (i = 0; i < samples->length; i++) + { + vec = VectorArrayGet(samples, i); + + if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0) + { + VectorArraySet(centers, centers->length, vec); + centers->length++; + } + } + + /* Fill remaining with random data */ + while (centers->length < centers->maxlen) + { + vec = VectorArrayGet(centers, centers->length); + + SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); + vec->dim = dimensions; + + for (j = 0; j < dimensions; j++) + vec->x[j] = ((double) random()) / MAX_RANDOM_VALUE; + + /* Normalize if needed (only needed for random 
centers) */ + if (normprocinfo != NULL) + ApplyNorm(normprocinfo, collation, vec); + + centers->length++; + } +} + +/* + * Use Elkan for performance. This requires distance function to satisfy triangle inequality. + * + * We use L2 distance for L2 (not L2 squared like index scan) + * and angular distance for inner product and cosine distance + * + * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf + */ +static void +ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) +{ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; + Vector *vec; + Vector *newCenter; + int iteration; + int j; + int k; + int dimensions = centers->dim; + int numCenters = centers->maxlen; + int numSamples = samples->length; + VectorArray newCenters; + int *centerCounts; + int *closestCenters; + double *lowerBound; + double *upperBound; + double *s; + double *halfcdist; + double *newcdist; + int changes; + double minDistance; + int closestCenter; + double distance; + bool rj; + bool rjreset; + double dxcx; + double dxc; + + /* Set support functions */ + procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); + normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); + collation = index->rd_indcollation[0]; + + /* Allocate space */ + centerCounts = palloc(sizeof(int) * numCenters); + closestCenters = palloc(sizeof(int) * numSamples); + lowerBound = palloc(sizeof(double) * numSamples * numCenters); + upperBound = palloc(sizeof(double) * numSamples); + s = palloc(sizeof(double) * numCenters); + halfcdist = palloc(sizeof(double) * numCenters * numCenters); + newcdist = palloc(sizeof(double) * numCenters); + + newCenters = VectorArrayInit(numCenters, dimensions); + for (j = 0; j < numCenters; j++) + { + vec = VectorArrayGet(newCenters, j); + SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); + vec->dim = dimensions; + } + + /* Pick initial centers */ + InitCenters(index, samples, centers, lowerBound); + + /* Assign each x to its closest initial 
center c(x) = argmin d(x,c) */ + for (j = 0; j < numSamples; j++) + { + minDistance = DBL_MAX; + closestCenter = -1; + + vec = VectorArrayGet(samples, j); + + /* Find closest center */ + for (k = 0; k < numCenters; k++) + { + /* TODO Use Lemma 1 in k-means++ initialization */ + distance = lowerBound[j * numCenters + k]; + + if (distance < minDistance) + { + minDistance = distance; + closestCenter = k; + } + } + + upperBound[j] = minDistance; + closestCenters[j] = closestCenter; + } + + /* Give 500 iterations to converge */ + for (iteration = 0; iteration < 500; iteration++) + { + /* Can take a while, so ensure we can interrupt */ + CHECK_FOR_INTERRUPTS(); + + changes = 0; + + /* Step 1: For all centers, compute distance */ + for (j = 0; j < numCenters; j++) + { + vec = VectorArrayGet(centers, j); + + for (k = j + 1; k < numCenters; k++) + { + distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); + halfcdist[j * numCenters + k] = distance; + halfcdist[k * numCenters + j] = distance; + } + } + + /* For all centers c, compute s(c) */ + for (j = 0; j < numCenters; j++) + { + minDistance = DBL_MAX; + + for (k = 0; k < numCenters; k++) + { + if (j == k) + continue; + + distance = halfcdist[j * numCenters + k]; + if (distance < minDistance) + minDistance = distance; + } + + s[j] = minDistance; + } + + rjreset = iteration != 0; + + for (j = 0; j < numSamples; j++) + { + /* Step 2: Identify all points x such that u(x) <= s(c(x)) */ + if (upperBound[j] <= s[closestCenters[j]]) + continue; + + rj = rjreset; + + for (k = 0; k < numCenters; k++) + { + /* Step 3: For all remaining points x and centers c */ + if (k == closestCenters[j]) + continue; + + if (upperBound[j] <= lowerBound[j * numCenters + k]) + continue; + + if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k]) + continue; + + vec = VectorArrayGet(samples, j); + + /* Step 3a */ + if (rj) + { + dxcx = 
DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j])))); + + /* d(x,c(x)) computed, which is a form of d(x,c) */ + lowerBound[j * numCenters + closestCenters[j]] = dxcx; + upperBound[j] = dxcx; + + rj = false; + } + else + dxcx = upperBound[j]; + + /* Step 3b */ + if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k]) + { + dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); + + /* d(x,c) calculated */ + lowerBound[j * numCenters + k] = dxc; + + if (dxc < dxcx) + { + closestCenters[j] = k; + + /* c(x) changed */ + upperBound[j] = dxc; + + changes++; + } + + } + } + } + + /* Step 4: For each center c, let m(c) be mean of all points assigned */ + for (j = 0; j < numCenters; j++) + { + vec = VectorArrayGet(newCenters, j); + for (k = 0; k < dimensions; k++) + vec->x[k] = 0.0; + + centerCounts[j] = 0; + } + + for (j = 0; j < numSamples; j++) + { + vec = VectorArrayGet(samples, j); + closestCenter = closestCenters[j]; + + /* Increment sum and count of closest center */ + newCenter = VectorArrayGet(newCenters, closestCenter); + for (k = 0; k < dimensions; k++) + newCenter->x[k] += vec->x[k]; + + centerCounts[closestCenter] += 1; + } + + for (j = 0; j < numCenters; j++) + { + vec = VectorArrayGet(newCenters, j); + + if (centerCounts[j] > 0) + { + for (k = 0; k < dimensions; k++) + vec->x[k] /= centerCounts[j]; + } + else + { + /* TODO Handle empty centers properly */ + for (k = 0; k < dimensions; k++) + vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE; + } + + /* Normalize if needed */ + if (normprocinfo != NULL) + ApplyNorm(normprocinfo, collation, vec); + } + + /* Step 5 */ + for (j = 0; j < numCenters; j++) + newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j)))); 
+ + for (j = 0; j < numSamples; j++) + { + for (k = 0; k < numCenters; k++) + { + distance = lowerBound[j * numCenters + k] - newcdist[k]; + + if (distance < 0) + distance = 0; + + lowerBound[j * numCenters + k] = distance; + } + } + + /* Step 6 */ + /* We reset r(x) before Step 3 in the next iteration */ + for (j = 0; j < numSamples; j++) + upperBound[j] += newcdist[closestCenters[j]]; + + /* Step 7 */ + for (j = 0; j < numCenters; j++) + memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions)); + + if (changes == 0 && iteration != 0) + break; + } + + pfree(newCenters); + pfree(centerCounts); + pfree(closestCenters); + pfree(lowerBound); + pfree(upperBound); + pfree(s); + pfree(halfcdist); + pfree(newcdist); +} + +/* + * Detect issues with centers + */ +static void +CheckCenters(Relation index, VectorArray centers) +{ + FmgrInfo *normprocinfo; + Oid collation; + int i; + double norm; + + if (centers->length != centers->maxlen) + elog(ERROR, "Not enough centers. Please report a bug."); + + /* Ensure no duplicate centers */ + /* Fine to sort in-place */ + qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors); + for (i = 1; i < centers->length; i++) + { + if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0) + elog(ERROR, "Duplicate centers detected. Please report a bug."); + } + + /* Ensure no zero vectors for cosine distance */ + /* Check NORM_PROC instead of KMEANS_NORM_PROC */ + normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); + if (normprocinfo != NULL) + { + collation = index->rd_indcollation[0]; + + for (i = 0; i < centers->length; i++) + { + norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); + if (norm == 0) + elog(ERROR, "Zero norm detected. 
Please report a bug.");
+		}
+	}
+}
+
+/*
+ * Compute the list centers: copy the samples when there are no more samples than centers, otherwise run Elkan's k-means; the result is then validated
+ * We use spherical k-means for inner product and cosine
+ */
+void
+IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
+{
+	if (samples->length <= centers->maxlen)
+		QuickCenters(index, samples, centers);
+	else
+		ElkanKmeans(index, samples, centers);
+
+	CheckCenters(index, centers);
+}
diff --git a/ivfscan.c b/ivfscan.c
new file mode 100644
index 0000000..99dc0f3
--- /dev/null
+++ b/ivfscan.c
@@ -0,0 +1,331 @@
+#include "postgres.h"
+
+#include "access/relscan.h"
+#include "ivfflat.h"
+#include "miscadmin.h"
+#include "storage/bufmgr.h"
+
+#if PG_VERSION_NUM >= 110000
+#include "catalog/pg_operator_d.h"
+#include "catalog/pg_type_d.h"
+#else
+#include "catalog/pg_operator.h"
+#include "catalog/pg_type.h"
+#endif
+
+/*
+ * Compare list distances (qsort comparator over IvfflatScanList, ascending by distance)
+ */
+static int
+CompareLists(const void *a, const void *b)
+{
+	double diff = (((const IvfflatScanList *) a)->distance - ((const IvfflatScanList *) b)->distance);
+
+	if (diff > 0)
+		return 1;
+
+	if (diff < 0)
+		return -1;
+
+	return 0;
+}
+
+/*
+ * Get lists and sort by distance: fills so->lists from every list page and caps so->probes at the number of lists found
+ */
+static void
+GetScanLists(IndexScanDesc scan, Datum value)
+{
+	Buffer cbuf;
+	Page cpage;
+	IvfflatList list;
+	OffsetNumber offno;
+	OffsetNumber maxoffno;
+	BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
+	int listCount = 0;
+	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
+	double distance;
+
+	/* Search all list pages */
+	while (BlockNumberIsValid(nextblkno))
+	{
+		cbuf = ReadBuffer(scan->indexRelation, nextblkno);
+		LockBuffer(cbuf, BUFFER_LOCK_SHARE);
+		cpage = BufferGetPage(cbuf);
+
+		maxoffno = PageGetMaxOffsetNumber(cpage);
+
+		for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
+		{
+			list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
+
+			/* Use procinfo from the index instead of scan key for performance */
+			distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
+
+			so->lists[listCount].startPage = list->startPage;
+			so->lists[listCount].distance = distance;
+			listCount++;
+		}
+
+		nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
+
+		UnlockReleaseBuffer(cbuf);
+	}
+
+	/* Sort by distance */
+	qsort(so->lists, listCount, sizeof(IvfflatScanList), CompareLists);
+
+	if (so->probes > listCount)
+		so->probes = listCount;
+}
+
+/*
+ * Get items: feed (distance, heap blkno, heap offset, index blkno) for every tuple in the so->probes closest lists into so->sortstate
+ */
+static void
+GetScanItems(IndexScanDesc scan, Datum value)
+{
+	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
+	Buffer buf;
+	Page page;
+	IndexTuple itup;
+	BlockNumber searchPage;
+	OffsetNumber offno;
+	OffsetNumber maxoffno;
+	Datum datum;
+	bool isnull;
+	int i;
+	TupleDesc tupdesc = RelationGetDescr(scan->indexRelation);
+
+#if PG_VERSION_NUM >= 120000
+	TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsVirtual);
+#else
+	TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc);
+#endif
+
+	/*
+	 * Reuse same set of shared buffers for scan
+	 *
+	 * See postgres/src/backend/storage/buffer/README for description
+	 */
+	BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
+
+	/* Search closest probes lists */
+	for (i = 0; i < so->probes; i++)
+	{
+		searchPage = so->lists[i].startPage;
+
+		/* Search all entry pages for list */
+		while (BlockNumberIsValid(searchPage))
+		{
+			buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
+			LockBuffer(buf, BUFFER_LOCK_SHARE);
+			page = BufferGetPage(buf);
+			maxoffno = PageGetMaxOffsetNumber(page);
+
+			for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
+			{
+				itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
+				datum = index_getattr(itup, 1, tupdesc, &isnull);
+
+				/*
+				 * Add virtual tuple
+				 *
+				 * Use procinfo from the index instead of scan key for
+				 * performance
+				 */
+				ExecClearTuple(slot);
+				slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value);
+				slot->tts_isnull[0] = false;
+				slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid));
+				slot->tts_isnull[1] = false;
+				slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid));
+				slot->tts_isnull[2] = false;
+				slot->tts_values[3] = Int32GetDatum((int) searchPage);
+				slot->tts_isnull[3] = false;
+				ExecStoreVirtualTuple(slot);
+
+				tuplesort_puttupleslot(so->sortstate, slot);
+			}
+
+			searchPage = IvfflatPageGetOpaque(page)->nextblkno;
+
+			UnlockReleaseBuffer(buf);
+		}
+	}
+
+	/* Both were allocated by this call; release them so repeated rescans do not accumulate them */
+	FreeAccessStrategy(bas);
+	ExecDropSingleTupleTableSlot(slot);
+}
+
+/*
+ * Prepare for an index scan: allocate the scan state sized for the index's list count and set up the distance sort
+ */
+IndexScanDesc
+ivfflatbeginscan(Relation index, int nkeys, int norderbys)
+{
+	IndexScanDesc scan;
+	IvfflatScanOpaque so;
+	int lists;
+	AttrNumber attNums[] = {1};
+	Oid sortOperators[] = {Float8LessOperator};
+	Oid sortCollations[] = {InvalidOid};
+	bool nullsFirstFlags[] = {false};
+
+	scan = RelationGetIndexScan(index, nkeys, norderbys);
+	lists = IvfflatGetLists(scan->indexRelation);
+
+	so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList));
+	so->buf = InvalidBuffer;
+	so->first = true;
+
+	/* Set support functions */
+	so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
+	so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
+	so->collation = index->rd_indcollation[0];
+
+	/* Create tuple description for sorting */
+#if PG_VERSION_NUM >= 120000
+	so->tupdesc = CreateTemplateTupleDesc(4);
+#else
+	so->tupdesc = CreateTemplateTupleDesc(4, false);
+#endif
+	TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0);
+	TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
+	TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
+	TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0);
+
+	/* Prep sort */
+#if PG_VERSION_NUM >= 110000
+	so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators,
sortCollations, nullsFirstFlags, work_mem, NULL, false);
+#else
+	so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, false);
+#endif
+
+#if PG_VERSION_NUM >= 120000
+	so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple);
+#else
+	so->slot = MakeSingleTupleTableSlot(so->tupdesc);
+#endif
+
+	scan->opaque = so;
+
+	return scan;
+}
+
+/*
+ * Start or restart an index scan
+ *
+ * Marks the scan as not yet started, re-reads the ivfflat_probes setting,
+ * and copies the caller's scan keys and order-by keys into the descriptor.
+ *
+ * NOTE(review): the sort state is only reset on Postgres 13+. On older
+ * versions a rescan after tuples were fetched keeps the already-sorted
+ * state - confirm whether that case needs its own handling.
+ */
+void
+ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
+{
+	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
+
+#if PG_VERSION_NUM >= 130000
+	if (!so->first)
+		tuplesort_reset(so->sortstate);
+#endif
+
+	so->first = true;
+	so->probes = ivfflat_probes;
+
+	if (keys && scan->numberOfKeys > 0)
+		memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
+
+	if (orderbys && scan->numberOfOrderBys > 0)
+		memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
+}
+
+/*
+ * Fetch the next tuple in the given scan
+ *
+ * On the first call after a (re)scan the order-by argument is read,
+ * optionally normalized through the opclass norm procedure, handed to
+ * GetScanLists/GetScanItems, and the tuplesort is sorted. Every call then
+ * pops the next sorted tuple, stores its heap TID in the scan descriptor
+ * and pins the index page recorded in that tuple.
+ *
+ * Returns false when the order-by argument is NULL, when normalization
+ * fails, or when the sort is exhausted.
+ */
+bool
+ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
+{
+	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
+
+	/*
+	 * Index can be used to scan backward, but Postgres doesn't support
+	 * backward scan on operators
+	 */
+	Assert(ScanDirectionIsForward(dir));
+
+	if (so->first)
+	{
+		Datum value;
+
+		/* No items will match if null */
+		if (scan->orderByData->sk_flags & SK_ISNULL)
+			return false;
+
+		value = scan->orderByData->sk_argument;
+
+		if (so->normprocinfo != NULL)
+		{
+			/* No items will match if normalization fails */
+			if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL))
+				return false;
+		}
+
+		GetScanLists(scan, value);
+		GetScanItems(scan, value);
+		tuplesort_performsort(so->sortstate);
+		so->first = false;
+
+		/* Clean up if we allocated a new value */
+		if (value != scan->orderByData->sk_argument)
+			pfree(DatumGetPointer(value));
+	}
+
+#if PG_VERSION_NUM >= 100000
+	if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL))
+#else
+	if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL))
+#endif
+	{
+		BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull));	/* attribute 2: heap block */
+		OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));	/* attribute 3: heap offset */
+		BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull));	/* attribute 4: index block to pin */
+
+#if PG_VERSION_NUM >= 120000
+		ItemPointerSet(&scan->xs_heaptid, blkno, offset);
+#else
+		ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset);
+#endif
+
+		if (BufferIsValid(so->buf))
+			ReleaseBuffer(so->buf);
+
+		/*
+		 * An index scan must maintain a pin on the index page holding the
+		 * item last returned by amgettuple
+		 *
+		 * https://www.postgresql.org/docs/current/index-locking.html
+		 */
+		so->buf = ReadBuffer(scan->indexRelation, indexblkno);
+
+		scan->xs_recheckorderby = false;
+		return true;
+	}
+
+	return false;
+}
+
+/*
+ * End a scan and release resources
+ *
+ * Drops the pin held on the last returned index page (if any), ends the
+ * tuplesort, frees the scan state and clears scan->opaque.
+ */
+void
+ivfflatendscan(IndexScanDesc scan)
+{
+	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
+
+	/* Release pin */
+	if (BufferIsValid(so->buf))
+		ReleaseBuffer(so->buf);
+
+	tuplesort_end(so->sortstate);
+
+	pfree(so);
+	scan->opaque = NULL;
+}
diff --git a/ivfutils.c b/ivfutils.c
new file mode 100644
index 0000000..f3a2cf9
--- /dev/null
+++ b/ivfutils.c
@@ -0,0 +1,176 @@
+#include "postgres.h"
+
+#include "ivfflat.h"
+#include "storage/bufmgr.h"
+#include "vector.h"
+
+/*
+ * Allocate a vector array
+ *
+ * The array is zero-filled and starts with length 0; maxlen and dim are
+ * recorded from the arguments.
+ */
+VectorArray
+VectorArrayInit(int maxlen, int dimensions)
+{
+	VectorArray res = palloc0(VECTOR_ARRAY_SIZE(maxlen, dimensions));
+
+	res->length = 0;
+	res->maxlen = maxlen;
+	res->dim = dimensions;
+	return res;
+}
+
+/*
+ * Print vector array - useful for debugging
+ *
+ * Calls PrintVector once per stored element, using the same msg for each.
+ */
+void
+PrintVectorArray(char *msg, VectorArray arr)
+{
+	int i;
+
+	for (i = 0; i < arr->length; i++)
+		PrintVector(msg, VectorArrayGet(arr, i));
+}
+
+/*
+ * Get the number of lists in the index
+ *
+ * Falls back to IVFFLAT_DEFAULT_LISTS when the index has no reloptions.
+ */
+int
+IvfflatGetLists(Relation index)
+{
+	IvfflatOptions *opts = (IvfflatOptions *) index->rd_options;
+
+	if (opts)
+		return opts->lists;
+
+	return IVFFLAT_DEFAULT_LISTS;
+}
+
+/*
+ * Get proc
+ *
+ * Returns the support procedure procnum for the first index column, or
+ * NULL when the operator class does not define it.
+ */
+FmgrInfo *
+IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
+{
+	if (!OidIsValid(index_getprocid(rel, 1, procnum)))
+		return NULL;
+
+	return index_getprocinfo(rel, 1, procnum);
+}
+
+/*
+ * Divide by the norm
+ *
+ * Returns false if value should not be indexed
+ *
+ * The caller needs to free the pointer stored in value
+ * if it's different than the original value
+ *
+ * The norm is obtained by calling procinfo on *value. When it is greater
+ * than zero each element is divided by it and *value is replaced with the
+ * result; otherwise *value is left untouched. If result is NULL a new
+ * vector is allocated with InitVector, else the caller's buffer is used.
+ */
+bool
+IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
+{
+	Vector *v;
+	int i;
+	double norm;
+
+	norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
+
+	if (norm > 0)
+	{
+		v = (Vector *) DatumGetPointer(*value);
+
+		if (result == NULL)
+			result = InitVector(v->dim);
+
+		for (i = 0; i < v->dim; i++)
+			result->x[i] = v->x[i] / norm;
+
+		*value = PointerGetDatum(result);
+
+		return true;
+	}
+
+	return false;
+}
+
+/*
+ * New buffer
+ *
+ * Extends the given fork by one page (P_NEW) and returns the buffer
+ * already locked in exclusive mode.
+ */
+Buffer
+IvfflatNewBuffer(Relation index, ForkNumber forkNum)
+{
+	Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
+
+	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
+	return buf;
+}
+
+/*
+ * Init page
+ *
+ * Starts a generic WAL record, registers *buf with a full-page image and
+ * initializes the page with an ivfflat opaque area that has no next page.
+ * The started record is returned through *state.
+ */
+void
+IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
+{
+	*state = GenericXLogStart(index);
+	*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
+	PageInit(*page, BufferGetPageSize(*buf), sizeof(IvfflatPageOpaqueData));
+	IvfflatPageGetOpaque(*page)->nextblkno = InvalidBlockNumber;
+	IvfflatPageGetOpaque(*page)->page_id = IVFFLAT_PAGE_ID;
+}
+
+/*
+ * Commit buffer
+ *
+ * Marks the buffer dirty, finishes the generic WAL record, then unlocks
+ * and unpins the buffer.
+ */
+void
+IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
+{
+	MarkBufferDirty(buf);
+	GenericXLogFinish(state);
+	UnlockReleaseBuffer(buf);
+}
+
+/*
+ * Add a new page
+ *
+ * The order is very important!!
+ */ +void +IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum) +{ + Buffer prevbuf = *buf; + + /* Get new buffer */ + *buf = IvfflatNewBuffer(index, forkNum); + + /* Update and commit previous buffer */ + IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(*buf); + IvfflatCommitBuffer(prevbuf, *state); + + /* Init new page */ + IvfflatInitPage(index, buf, page, state); +} + +/* + * Update the start or insert page of a list + */ +void +IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, + BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum) +{ + Buffer buf; + Page page; + IvfflatList list; + + buf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + list = (IvfflatList) PageGetItem(page, PageGetItemId(page, listInfo.offno)); + + if (BlockNumberIsValid(insertPage)) + list->insertPage = insertPage; + + if (BlockNumberIsValid(startPage)) + list->startPage = startPage; + + /* Could only commit if changed, but extra complexity isn't needed */ + IvfflatCommitBuffer(buf, state); +} diff --git a/ivfvacuum.c b/ivfvacuum.c new file mode 100644 index 0000000..7c27bdf --- /dev/null +++ b/ivfvacuum.c @@ -0,0 +1,151 @@ +#include "postgres.h" + +#include "commands/vacuum.h" +#include "ivfflat.h" +#include "storage/bufmgr.h" + +/* + * Bulk delete tuples from the index + */ +IndexBulkDeleteResult * +ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, + IndexBulkDeleteCallback callback, void *callback_state) +{ + Relation index = info->index; + Buffer cbuf; + Page cpage; + Buffer buf; + Page page; + IvfflatList list; + IndexTuple itup; + ItemPointer htup; + OffsetNumber deletable[MaxOffsetNumber]; + int ndeletable; + OffsetNumber startPages[MaxOffsetNumber]; + BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; + BlockNumber 
searchPage; + BlockNumber insertPage; + GenericXLogState *state; + OffsetNumber coffno; + OffsetNumber cmaxoffno; + OffsetNumber offno; + OffsetNumber maxoffno; + ListInfo listInfo; + BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); + + if (stats == NULL) + stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); + + /* Iterate over list pages */ + while (BlockNumberIsValid(nextblkno)) + { + cbuf = ReadBuffer(index, nextblkno); + LockBuffer(cbuf, BUFFER_LOCK_SHARE); + cpage = BufferGetPage(cbuf); + + cmaxoffno = PageGetMaxOffsetNumber(cpage); + + /* Iterate over lists */ + for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) + { + list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); + startPages[coffno - FirstOffsetNumber] = list->startPage; + } + + listInfo.blkno = nextblkno; + nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; + + UnlockReleaseBuffer(cbuf); + + for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) + { + searchPage = startPages[coffno - FirstOffsetNumber]; + insertPage = InvalidBlockNumber; + + /* Iterate over entry pages */ + while (BlockNumberIsValid(searchPage)) + { + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas); + + /* + * ambulkdelete cannot delete entries from pages that are + * pinned by other backends + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + LockBufferForCleanup(buf); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + maxoffno = PageGetMaxOffsetNumber(page); + ndeletable = 0; + + /* Find deleted tuples */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno)); + htup = &(itup->t_tid); + + if (callback(htup, callback_state)) + { + deletable[ndeletable++] = offno; + stats->tuples_removed++; + } + 
else + stats->num_index_tuples++; + } + + searchPage = IvfflatPageGetOpaque(page)->nextblkno; + + if (ndeletable > 0) + { + /* Delete tuples */ + PageIndexMultiDelete(page, deletable, ndeletable); + MarkBufferDirty(buf); + GenericXLogFinish(state); + + /* Set to first free page */ + if (!BlockNumberIsValid(insertPage)) + insertPage = searchPage; + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } + + /* + * Update after all tuples deleted. + * + * We don't add or delete items from lists pages, so offset won't + * change. + */ + if (!BlockNumberIsValid(insertPage)) + { + listInfo.offno = coffno; + IvfflatUpdateList(index, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM); + } + } + } + + return stats; +} + +/* + * Clean up after a VACUUM operation + */ +IndexBulkDeleteResult * +ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) +{ + Relation rel = info->index; + + if (stats == NULL) + return NULL; + + stats->num_pages = RelationGetNumberOfBlocks(rel); + + return stats; +} diff --git a/sql/btree.sql b/sql/btree.sql new file mode 100644 index 0000000..32d0e4b --- /dev/null +++ b/sql/btree.sql @@ -0,0 +1,12 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t (val); + +SELECT * FROM t WHERE val = '[1,2,3]'; +SELECT * FROM t ORDER BY val LIMIT 1; + +DROP TABLE t; diff --git a/sql/cast.sql b/sql/cast.sql new file mode 100644 index 0000000..a1bf746 --- /dev/null +++ b/sql/cast.sql @@ -0,0 +1,11 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; + +SELECT ARRAY[1,2,3]::vector; +SELECT ARRAY[1,2,3]::float4[]::vector; +SELECT ARRAY[1,2,3]::float8[]::vector; +SELECT '{NULL}'::real[]::vector; +SELECT '{NaN}'::real[]::vector; +SELECT '{Infinity}'::real[]::vector; +SELECT '{-Infinity}'::real[]::vector; +SELECT 
'{}'::real[]::vector; diff --git a/sql/functions.sql b/sql/functions.sql new file mode 100644 index 0000000..7f397a2 --- /dev/null +++ b/sql/functions.sql @@ -0,0 +1,18 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; + +SELECT '[1,2,3]'::vector + '[4,5,6]'; +SELECT '[1,2,3]'::vector - '[4,5,6]'; + +SELECT vector_dims('[1,2,3]'); +SELECT round(vector_norm('[1,1]')::numeric, 5); + +SELECT round(l2_distance('[1,2]', '[0,0]')::numeric, 5); +SELECT l2_distance('[1,2]', '[3]'); + +SELECT inner_product('[1,2]', '[3,4]'); +SELECT inner_product('[1,2]', '[3]'); + +SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5); +SELECT cosine_distance('[1,2]', '[0,0]'); +SELECT cosine_distance('[1,2]', '[3]'); diff --git a/sql/ivfflat_cosine.sql b/sql/ivfflat_cosine.sql new file mode 100644 index 0000000..cc4522a --- /dev/null +++ b/sql/ivfflat_cosine.sql @@ -0,0 +1,14 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); + +DROP TABLE t; diff --git a/sql/ivfflat_ip.sql b/sql/ivfflat_ip.sql new file mode 100644 index 0000000..342f40f --- /dev/null +++ b/sql/ivfflat_ip.sql @@ -0,0 +1,14 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); + +DROP TABLE t; diff --git a/sql/ivfflat_l2.sql 
b/sql/ivfflat_l2.sql new file mode 100644 index 0000000..336434e --- /dev/null +++ b/sql/ivfflat_l2.sql @@ -0,0 +1,14 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); + +DROP TABLE t; diff --git a/sql/ivfflat_options.sql b/sql/ivfflat_options.sql new file mode 100644 index 0000000..1e2c1b0 --- /dev/null +++ b/sql/ivfflat_options.sql @@ -0,0 +1,11 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 0); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 32769); + +SHOW ivfflat.probes; + +DROP TABLE t; diff --git a/sql/ivfflat_unlogged.sql b/sql/ivfflat_unlogged.sql new file mode 100644 index 0000000..8fa426f --- /dev/null +++ b/sql/ivfflat_unlogged.sql @@ -0,0 +1,11 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; +SET enable_seqscan = off; + +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/sql/vector.sql b/sql/vector.sql new file mode 100644 index 0000000..75af74e --- /dev/null +++ b/sql/vector.sql @@ -0,0 +1,15 @@ +SET client_min_messages = warning; +CREATE EXTENSION IF NOT EXISTS vector; + +SELECT '[1,2,3]'::vector; +SELECT '[-1,2,3]'::vector; +SELECT '[hello,1]'::vector; +SELECT '[NaN,1]'::vector; +SELECT '[Infinity,1]'::vector; +SELECT '[-Infinity,1]'::vector; +SELECT '[1,2,3'::vector; +SELECT '[1,2,3]9'::vector; +SELECT 
'1,2,3'::vector; +SELECT '[]'::vector; +SELECT '[1,]'::vector; +SELECT '[1,2,3]'::vector(2); diff --git a/t/001_wal.pl b/t/001_wal.pl new file mode 100644 index 0000000..e1fe5fe --- /dev/null +++ b/t/001_wal.pl @@ -0,0 +1,80 @@ +# Based on postgres/contrib/bloom/t/001_wal.pl + +# Test generic xlog record work for ivfflat index replication. +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More tests => 31; + +my $node_primary; +my $node_replica; + +# Run few queries on both primary and replica and check their results match. +sub test_index_replay +{ + my ($test_name) = @_; + + # Wait for replica to catch up + my $applname = $node_replica->name; + my $caughtup_query = + "SELECT pg_current_wal_lsn() <= write_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; + $node_primary->poll_query_until('postgres', $caughtup_query) + or die "Timed out while waiting for replica 1 to catch up"; + + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + + my $queries = qq(SET enable_seqscan=off; +SELECT * FROM tst ORDER BY v <-> '[$r1,$r2,$r3]' LIMIT 10; +); + + # Run test queries and compare their result + my $primary_result = $node_primary->safe_psql("postgres", $queries); + my $replica_result = $node_replica->safe_psql("postgres", $queries); + + is($primary_result, $replica_result, "$test_name: query result matches"); + return; +} + +# Initialize primary node +$node_primary = get_new_node('primary'); +$node_primary->init(allows_streaming => 1); +$node_primary->start; +my $backup_name = 'my_backup'; + +# Take backup +$node_primary->backup($backup_name); + +# Create streaming replica linking to primary +$node_replica = get_new_node('replica'); +$node_replica->init_from_backup($node_primary, $backup_name, + has_streaming => 1); +$node_replica->start; + +# Create ivfflat index on primary +$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); 
+$node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i%10, ARRAY[random(), random(), random()] FROM generate_series(1,100000) i;" +); +$node_primary->safe_psql("postgres", + "CREATE INDEX ON tst USING ivfflat (v);"); + +# Test that queries give same result +test_index_replay('initial'); + +# Run 10 cycles of table modification. Run test queries after each modification. +for my $i (1 .. 10) +{ + $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); + test_index_replay("delete $i"); + $node_primary->safe_psql("postgres", "VACUUM tst;"); + test_index_replay("vacuum $i"); + my ($start, $end) = (100001 + ($i - 1) * 10000, 100000 + $i * 10000); + $node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i%10, ARRAY[random(), random(), random()] FROM generate_series($start,$end) i;" + ); + test_index_replay("insert $i"); +} diff --git a/vector--0.1.0.sql b/vector--0.1.0.sql new file mode 100644 index 0000000..8e7a064 --- /dev/null +++ b/vector--0.1.0.sql @@ -0,0 +1,210 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "CREATE EXTENSION vector" to load this file. 
\quit + +-- type + +CREATE TYPE vector; + +CREATE FUNCTION vector_in(cstring, oid, integer) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_out(vector) RETURNS cstring + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_typmod_in(cstring[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE TYPE vector ( + INPUT = vector_in, + OUTPUT = vector_out, + TYPMOD_IN = vector_typmod_in +); + +-- functions + +CREATE FUNCTION l2_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION inner_product(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION cosine_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_dims(vector) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_norm(vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_add(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_sub(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +-- private functions + +CREATE FUNCTION vector_lt(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_le(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_eq(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_ne(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_ge(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION vector_gt(vector, vector) RETURNS bool + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; + +CREATE FUNCTION 
vector_cmp(vector, vector) RETURNS int4
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+-- cast functions
+
+CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION array_to_vector(integer[], integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION array_to_vector(real[], integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+CREATE FUNCTION array_to_vector(double precision[], integer, boolean) RETURNS vector
+	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
+
+-- casts
+
+CREATE CAST (vector AS vector)
+	WITH FUNCTION vector(vector, integer, boolean) AS IMPLICIT;
+
+CREATE CAST (integer[] AS vector)
+	WITH FUNCTION array_to_vector(integer[], integer, boolean) AS IMPLICIT;
+
+CREATE CAST (real[] AS vector)
+	WITH FUNCTION array_to_vector(real[], integer, boolean) AS IMPLICIT;
+
+CREATE CAST (double precision[] AS vector)
+	WITH FUNCTION array_to_vector(double precision[], integer, boolean) AS IMPLICIT;
+
+-- operators
+
+CREATE OPERATOR <-> (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = l2_distance,
+	COMMUTATOR = '<->'
+);
+
+CREATE OPERATOR <#> (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_negative_inner_product,
+	COMMUTATOR = '<#>'
+);
+
+CREATE OPERATOR <=> (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = cosine_distance,
+	COMMUTATOR = '<=>'
+);
+
+CREATE OPERATOR + (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_add,
+	COMMUTATOR = +
+);
+
+-- subtraction is not commutative (a - b <> b - a), so no COMMUTATOR here
+CREATE OPERATOR - (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_sub
+);
+
+CREATE OPERATOR < (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt,
+	COMMUTATOR = > , NEGATOR = >= ,
+	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
+);
+
+-- should use scalarlesel and scalarlejoinsel, but not supported in Postgres < 11
+CREATE OPERATOR <= (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_le,
+	COMMUTATOR = >= , NEGATOR = > ,
+	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
+);
+
+CREATE OPERATOR = (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_eq,
+	COMMUTATOR = = , NEGATOR = <> ,
+	RESTRICT = eqsel, JOIN = eqjoinsel
+);
+
+-- inequality uses the "not equal" selectivity estimators
+CREATE OPERATOR <> (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ne,
+	COMMUTATOR = <> , NEGATOR = = ,
+	RESTRICT = neqsel, JOIN = neqjoinsel
+);
+
+-- should use scalargesel and scalargejoinsel, but not supported in Postgres < 11
+CREATE OPERATOR >= (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ge,
+	COMMUTATOR = <= , NEGATOR = < ,
+	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
+);
+
+CREATE OPERATOR > (
+	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_gt,
+	COMMUTATOR = < , NEGATOR = <= ,
+	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
+);
+
+-- access method
+
+CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler
+	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler;
+
+COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method';
+
+-- opclasses
+
+CREATE OPERATOR CLASS vector_ops
+	DEFAULT FOR TYPE vector USING btree AS
+	OPERATOR 1 < ,
+	OPERATOR 2 <= ,
+	OPERATOR 3 = ,
+	OPERATOR 4 >= ,
+	OPERATOR 5 > ,
+	FUNCTION 1 vector_cmp(vector, vector);
+
+CREATE OPERATOR CLASS vector_l2_ops
+	DEFAULT FOR TYPE vector USING ivfflat AS
+	OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
+	FUNCTION 1 vector_l2_squared_distance(vector, vector),
+	FUNCTION 3 l2_distance(vector, vector);
+
+CREATE OPERATOR CLASS
vector_ip_ops + FOR TYPE vector USING ivfflat AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 3 vector_spherical_distance(vector, vector), + FUNCTION 4 vector_norm(vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING ivfflat AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector), + FUNCTION 3 vector_spherical_distance(vector, vector), + FUNCTION 4 vector_norm(vector); diff --git a/vector.c b/vector.c new file mode 100644 index 0000000..fcdfcc3 --- /dev/null +++ b/vector.c @@ -0,0 +1,610 @@ +#include "postgres.h" + +#include + +#include "vector.h" +#include "fmgr.h" +#include "catalog/pg_type.h" +#include "lib/stringinfo.h" +#include "utils/array.h" +#include "utils/builtins.h" +#include "utils/lsyscache.h" + +#if PG_VERSION_NUM >= 120000 +#include "utils/float.h" +#endif + +PG_MODULE_MAGIC; + +/* + * Ensure same dimensions + */ +static inline void +CheckDims(Vector * a, Vector * b) +{ + if (a->dim != b->dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different vector dimensions %d and %d", a->dim, b->dim))); +} + +/* + * Ensure expected dimension + */ +static inline void +CheckExpectedDim(int32 typmod, int dim) +{ + if (typmod != -1 && typmod != dim) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("expected %d dimensions, not %d", typmod, dim))); +} + +/* + * Ensure finite elements + */ +static inline void +CheckElement(float value) +{ + if (isnan(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("NaN not allowed in vector"))); + + + if (isinf(value)) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("infinite value not allowed in vector"))); +} + +/* + * Print vector - useful for debugging + */ +void +PrintVector(char *msg, Vector * vector) +{ + StringInfoData buf; + int dim = vector->dim; + int 
i; + + initStringInfo(&buf); + + appendStringInfoChar(&buf, '['); + for (i = 0; i < dim; i++) + { + if (i > 0) + appendStringInfoString(&buf, ","); + appendStringInfoString(&buf, float8out_internal(vector->x[i])); + } + appendStringInfoChar(&buf, ']'); + + elog(INFO, "%s = %s", msg, buf.data); +} + +/* + * Convert textual representation to internal representation + */ +PG_FUNCTION_INFO_V1(vector_in); +Datum +vector_in(PG_FUNCTION_ARGS) +{ + char *str = PG_GETARG_CSTRING(0); + int32 typmod = PG_GETARG_INT32(2); + int i; + double x[VECTOR_MAX_DIM]; + int dim = 0; + char *pt; + char *stringEnd; + Vector *result; + + if (*str != '[') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed vector literal: \"%s\"", str), + errdetail("Vector contents must start with \"[\"."))); + + str++; + pt = strtok(str, ","); + stringEnd = pt; + + while (pt != NULL && *stringEnd != ']') + { + if (dim == VECTOR_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM))); + + x[dim] = strtod(pt, &stringEnd); + CheckElement(x[dim]); + dim++; + + if (stringEnd == pt) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type vector: \"%s\"", pt))); + + if (*stringEnd != '\0' && *stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("invalid input syntax for type vector: \"%s\"", pt))); + + pt = strtok(NULL, ","); + } + + if (*stringEnd != ']') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed vector literal"), + errdetail("Unexpected end of input."))); + + if (stringEnd[1] != '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), + errmsg("malformed vector literal"), + errdetail("Junk after closing right brace."))); + + if (dim < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("vector must have at least 1 dimension"))); + + 
CheckExpectedDim(typmod, dim); + + result = InitVector(dim); + for (i = 0; i < dim; i++) + result->x[i] = x[i]; + + PG_RETURN_POINTER(result); +} + +/* + * Convert internal representation to textual representation + */ +PG_FUNCTION_INFO_V1(vector_out); +Datum +vector_out(PG_FUNCTION_ARGS) +{ + Vector *vector = PG_GETARG_VECTOR_P(0); + StringInfoData buf; + int dim = vector->dim; + int i; + + initStringInfo(&buf); + + appendStringInfoChar(&buf, '['); + for (i = 0; i < dim; i++) + { + if (i > 0) + appendStringInfoString(&buf, ","); + + appendStringInfoString(&buf, float8out_internal(vector->x[i])); + } + appendStringInfoChar(&buf, ']'); + + PG_FREE_IF_COPY(vector, 0); + PG_RETURN_CSTRING(buf.data); +} + +/* + * Convert type modifier + */ +PG_FUNCTION_INFO_V1(vector_typmod_in); +Datum +vector_typmod_in(PG_FUNCTION_ARGS) +{ + ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0); + int32 *tl; + int n; + + tl = ArrayGetIntegerTypmods(ta, &n); + + if (n != 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("invalid type modifier"))); + + if (*tl < 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type vector must be at least 1"))); + + if (*tl > VECTOR_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("dimensions for type vector cannot exceed %d", VECTOR_MAX_DIM))); + + PG_RETURN_INT32(*tl); +} + +/* + * Convert vector to vector + */ +PG_FUNCTION_INFO_V1(vector); +Datum +vector(PG_FUNCTION_ARGS) +{ + Vector *arg = PG_GETARG_VECTOR_P(0); + int32 typmod = PG_GETARG_INT32(1); + + CheckExpectedDim(typmod, arg->dim); + + PG_RETURN_POINTER(arg); +} + +/* + * Convert array to vector + */ +PG_FUNCTION_INFO_V1(array_to_vector); +Datum +array_to_vector(PG_FUNCTION_ARGS) +{ + ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); + int32 typmod = PG_GETARG_INT32(1); + int i; + Vector *result; + int16 typlen; + bool typbyval; + char typalign; + Datum *elemsp; + bool *nullsp; + int nelemsp; + + if (ARR_NDIM(array) 
> 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("array must be 1-D"))); + + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); + deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, &nullsp, &nelemsp); + + if (typmod == -1) + { + if (nelemsp < 1) + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("vector must have at least 1 dimension"))); + + if (nelemsp > VECTOR_MAX_DIM) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM))); + } + else + CheckExpectedDim(typmod, nelemsp); + + result = InitVector(nelemsp); + for (i = 0; i < nelemsp; i++) + { + if (nullsp[i]) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not containing NULLs"))); + + if (ARR_ELEMTYPE(array) == INT4OID) + result->x[i] = DatumGetInt32(elemsp[i]); + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + result->x[i] = DatumGetFloat8(elemsp[i]); + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + result->x[i] = DatumGetFloat4(elemsp[i]); + else + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + + CheckElement(result->x[i]); + } + + PG_RETURN_POINTER(result); +} + +/* + * Get the L2 distance between vectors + */ +PG_FUNCTION_INFO_V1(l2_distance); +Datum +l2_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + distance += pow(a->x[i] - b->x[i], 2); + + PG_RETURN_FLOAT8(sqrt(distance)); +} + +/* + * Get the L2 squared distance between vectors + * This saves a sqrt calculation + */ +PG_FUNCTION_INFO_V1(vector_l2_squared_distance); +Datum +vector_l2_squared_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + distance 
+= pow(a->x[i] - b->x[i], 2); + + PG_RETURN_FLOAT8(distance); +} + +/* + * Get the inner product of two vectors + */ +PG_FUNCTION_INFO_V1(inner_product); +Datum +inner_product(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + distance += a->x[i] * b->x[i]; + + PG_RETURN_FLOAT8(distance); +} + +/* + * Get the negative inner product of two vectors + */ +PG_FUNCTION_INFO_V1(vector_negative_inner_product); +Datum +vector_negative_inner_product(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + distance += a->x[i] * b->x[i]; + + PG_RETURN_FLOAT8(distance * -1); +} + +/* + * Get the cosine distance between two vectors + */ +PG_FUNCTION_INFO_V1(cosine_distance); +Datum +cosine_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + double norma = 0.0; + double normb = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + { + distance += a->x[i] * b->x[i]; + norma += pow(a->x[i], 2); + normb += pow(b->x[i], 2); + } + + PG_RETURN_FLOAT8(1 - (distance / (sqrt(norma) * sqrt(normb)))); +} + +/* + * Get the distance for spherical k-means + * Currently uses angular distance since needs to satisfy triangle inequality + * Assumes inputs are unit vectors (skips norm) + */ +PG_FUNCTION_INFO_V1(vector_spherical_distance); +Datum +vector_spherical_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + double distance = 0.0; + + CheckDims(a, b); + + for (int i = 0; i < a->dim; i++) + distance += a->x[i] * b->x[i]; + + /* Prevent NaN with acos with loss of precision */ + if (distance > 1) + distance = 1; + else if (distance < -1) + distance = -1; + + PG_RETURN_FLOAT8(acos(distance) / M_PI); +} + 
+/*
+ * Number of dimensions stored in the vector.
+ */
+PG_FUNCTION_INFO_V1(vector_dims);
+Datum
+vector_dims(PG_FUNCTION_ARGS)
+{
+	Vector *vec = PG_GETARG_VECTOR_P(0);
+
+	PG_RETURN_INT32(vec->dim);
+}
+
+/*
+ * L2 norm: square root of the sum of squared elements.
+ */
+PG_FUNCTION_INFO_V1(vector_norm);
+Datum
+vector_norm(PG_FUNCTION_ARGS)
+{
+	Vector *vec = PG_GETARG_VECTOR_P(0);
+	double sumsq = 0.0;
+	int k;
+
+	for (k = 0; k < vec->dim; k++)
+	{
+		sumsq += pow(vec->x[k], 2);
+	}
+
+	PG_RETURN_FLOAT8(sqrt(sumsq));
+}
+
+/*
+ * Element-wise sum of two vectors of equal dimension.
+ */
+PG_FUNCTION_INFO_V1(vector_add);
+Datum
+vector_add(PG_FUNCTION_ARGS)
+{
+	Vector *lhs = PG_GETARG_VECTOR_P(0);
+	Vector *rhs = PG_GETARG_VECTOR_P(1);
+	Vector *sum;
+	int k;
+
+	CheckDims(lhs, rhs);
+
+	sum = InitVector(lhs->dim);
+	for (k = 0; k < lhs->dim; k++)
+	{
+		sum->x[k] = lhs->x[k] + rhs->x[k];
+	}
+
+	PG_RETURN_POINTER(sum);
+}
+
+/*
+ * Element-wise difference (first minus second) of two vectors of equal
+ * dimension.
+ */
+PG_FUNCTION_INFO_V1(vector_sub);
+Datum
+vector_sub(PG_FUNCTION_ARGS)
+{
+	Vector *lhs = PG_GETARG_VECTOR_P(0);
+	Vector *rhs = PG_GETARG_VECTOR_P(1);
+	Vector *diff;
+	int k;
+
+	CheckDims(lhs, rhs);
+
+	diff = InitVector(lhs->dim);
+	for (k = 0; k < lhs->dim; k++)
+	{
+		diff->x[k] = lhs->x[k] - rhs->x[k];
+	}
+
+	PG_RETURN_POINTER(diff);
+}
+
+/*
+ * Shared comparison helper
+ *
+ * Compares element by element and returns -1 or 1 at the first position
+ * where the vectors differ, 0 when every element is equal. Vectors of
+ * different dimensions raise an error.
+ */
+int
+vector_cmp_internal(Vector * a, Vector * b)
+{
+	int pos = 0;
+
+	CheckDims(a, b);
+
+	while (pos < a->dim)
+	{
+		if (a->x[pos] < b->x[pos])
+			return -1;
+
+		if (a->x[pos] > b->x[pos])
+			return 1;
+
+		pos++;
+	}
+
+	return 0;
+}
+
+/*
+ * a < b
+ */
+PG_FUNCTION_INFO_V1(vector_lt);
+Datum
+vector_lt(PG_FUNCTION_ARGS)
+{
+	Vector *lhs = PG_GETARG_VECTOR_P(0);
+	Vector *rhs = PG_GETARG_VECTOR_P(1);
+	int order = vector_cmp_internal(lhs, rhs);
+
+	PG_RETURN_BOOL(order < 0);
+}
+
+/*
+ * a <= b
+ */
+PG_FUNCTION_INFO_V1(vector_le);
+Datum
+vector_le(PG_FUNCTION_ARGS)
+{
+	Vector *lhs = PG_GETARG_VECTOR_P(0);
+	Vector *rhs = PG_GETARG_VECTOR_P(1);
+	int order = vector_cmp_internal(lhs, rhs);
+
+	PG_RETURN_BOOL(order <= 0);
+}
+
+/*
+ * Equal
+ */
+PG_FUNCTION_INFO_V1(vector_eq);
+Datum
+vector_eq(PG_FUNCTION_ARGS)
+{
+	Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);	/* cast is redundant: the macro already yields Vector * (see vector.h) */
+	Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
+
+	PG_RETURN_BOOL(vector_cmp_internal(a, b) == 0);
+}
+
+/*
+ * Not equal: true when vector_cmp_internal(a, b) is nonzero
+ */
+PG_FUNCTION_INFO_V1(vector_ne);
+Datum
+vector_ne(PG_FUNCTION_ARGS)
+{
+	Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
+	Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
+
+	PG_RETURN_BOOL(vector_cmp_internal(a, b) != 0);
+}
+
+/*
+ * Greater than or equal: true when vector_cmp_internal(a, b) is zero or positive
+ */
+PG_FUNCTION_INFO_V1(vector_ge);
+Datum
+vector_ge(PG_FUNCTION_ARGS)
+{
+	Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
+	Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
+
+	PG_RETURN_BOOL(vector_cmp_internal(a, b) >= 0);
+}
+
+/*
+ * Greater than: true when vector_cmp_internal(a, b) is positive
+ */
+PG_FUNCTION_INFO_V1(vector_gt);
+Datum
+vector_gt(PG_FUNCTION_ARGS)
+{
+	Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
+	Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
+
+	PG_RETURN_BOOL(vector_cmp_internal(a, b) > 0);
+}
+
+/*
+ * Compare vectors: returns the -1, 0 or 1 result of vector_cmp_internal as int32
+ */
+PG_FUNCTION_INFO_V1(vector_cmp);
+Datum
+vector_cmp(PG_FUNCTION_ARGS)
+{
+	Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
+	Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
+
+	PG_RETURN_INT32(vector_cmp_internal(a, b));
+}
diff --git a/vector.control b/vector.control
new file mode 100644
index 0000000..588b75c
--- /dev/null
+++ b/vector.control
@@ -0,0 +1,4 @@
+comment = 'vector data type and ivfflat access method'
+default_version = '0.1.0'
+module_pathname = '$libdir/vector'
+relocatable = true
diff --git a/vector.h b/vector.h
new file mode 100644
index 0000000..cf14b79
--- /dev/null
+++ b/vector.h
@@ -0,0 +1,41 @@
+#ifndef VECTOR_H
+#define VECTOR_H
+
+#include "postgres.h"
+
+#define VECTOR_MAX_DIM 1024	/* not referenced in this header; presumably the largest accepted dim - TODO confirm where it is enforced */
+
+#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))	/* bytes for the fixed header plus _dim float elements */
+#define DatumGetVector(x) ((Vector *) PG_DETOAST_DATUM(x))	/* passes the datum through PG_DETOAST_DATUM, then casts to Vector * */
+#define PG_GETARG_VECTOR_P(x) DatumGetVector(PG_GETARG_DATUM(x))	/* fetch argument x of the current call as Vector * */
+#define PG_RETURN_VECTOR_P(x) PG_RETURN_POINTER(x)	/* plain alias of PG_RETURN_POINTER */
+
+typedef struct Vector
+{
+	int32 vl_len_; /* varlena header (do not touch directly!); InitVector sets it through SET_VARSIZE */
+	int16 dim; /* number of dimensions, i.e. how many entries of x are in use */
+	int16 unused;	/* never assigned explicitly; left zero by palloc0 in InitVector */
+	float x[FLEXIBLE_ARRAY_MEMBER];	/* the elements, stored inline after the fixed fields (see VECTOR_SIZE) */
+} Vector;
+
+void PrintVector(char *msg, Vector * vector);	/* defined outside this header */
+int vector_cmp_internal(Vector * a, Vector * b);	/* returns -1, 0 or 1; implemented in vector.c */
+
+/*
+ * Allocate and initialize a new vector: zero-filled storage for dim elements, with varlena size and dim set
+ */
+static inline Vector *
+InitVector(int dim)
+{
+	Vector *result;
+	int size;	/* total allocation in bytes, header included */
+
+	size = VECTOR_SIZE(dim);
+	result = (Vector *) palloc0(size);	/* palloc0 zero-fills, so every element starts at 0 */
+	SET_VARSIZE(result, size);	/* record the total size in the varlena header */
+	result->dim = dim;	/* int narrowed to int16 with no range check here; callers presumably bound dim - TODO confirm */
+
+	return result;
+}
+
+#endif	/* VECTOR_H */