First commit

This commit is contained in:
Andrew Kane
2021-04-20 14:04:28 -07:00
commit 6df7fa05b2
37 changed files with 3724 additions and 0 deletions

6
.editorconfig Normal file
View File

@@ -0,0 +1,6 @@
root = true
[*.{c,h}]
indent_style = tab
indent_size = tab
tab_width = 4

20
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
# CI: build the extension and run both test suites against each Postgres version in the matrix
name: build
on: [push, pull_request]
jobs:
build:
# Skip the job when the head commit message contains "[skip ci]"
if: "!contains(github.event.head_commit.message, '[skip ci]')"
runs-on: ubuntu-latest
strategy:
# Keep running the remaining matrix entries when one fails
fail-fast: false
matrix:
postgres: [13, 12, 11, 10, 9.6]
steps:
- uses: actions/checkout@v2
- uses: ankane/setup-postgres@v1
with:
postgres-version: ${{ matrix.postgres }}
# Server headers for building; libipc-run-perl is presumably needed by the TAP tests -- TODO confirm
- run: sudo apt-get install postgresql-server-dev-${{ matrix.postgres }} libipc-run-perl
- run: make
- run: sudo make install
# Regression tests, then TAP tests
- run: make installcheck
- run: make prove_installcheck

5
.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
results
tmp_check
regression.*
*.o
*.so

3
CHANGELOG.md Normal file
View File

@@ -0,0 +1,3 @@
## 0.1.0 (unreleased)
- First release

20
LICENSE Normal file
View File

@@ -0,0 +1,20 @@
Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
Portions Copyright (c) 1994, The Regents of the University of California
Permission to use, copy, modify, and distribute this software and its
documentation for any purpose, without fee, and without a written agreement
is hereby granted, provided that the above copyright notice and this
paragraph and the following two paragraphs appear in all copies.
IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR
DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING
LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS
DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO
PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.

16
Makefile Normal file
View File

@@ -0,0 +1,16 @@
# PGXS build description for the "vector" extension
EXTENSION = vector
DATA = vector--0.1.0.sql
MODULE_big = vector
OBJS = ivfbuild.o ivfflat.o ivfinsert.o ivfkmeans.o ivfscan.o ivfutils.o ivfvacuum.o vector.o
# NOTE(review): TESTS is not referenced anywhere else in this Makefile; REGRESS below is listed by hand
TESTS = $(wildcard sql/*.sql)
REGRESS = btree cast functions ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_unlogged vector
# pg_config can be overridden from the environment to build against another installation
PG_CONFIG ?= pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
include $(PGXS)
# Run the TAP tests (t/*.pl unless PROVE_TESTS is set) against the installed extension
prove_installcheck:
rm -rf $(CURDIR)/tmp_check
cd $(srcdir) && TESTDIR='$(CURDIR)' PATH="$(bindir):$$PATH" PGPORT='6$(DEF_PGPORT)' PG_REGRESS='$(top_builddir)/src/test/regress/pg_regress' $(PROVE) $(PG_PROVE_FLAGS) $(PROVE_FLAGS) $(if $(PROVE_TESTS),$(PROVE_TESTS),t/*.pl)

189
README.md Normal file
View File

@@ -0,0 +1,189 @@
# pgvector
Open-source vector similarity search for Postgres
```sql
CREATE TABLE table (column vector(3));
CREATE INDEX ON table USING ivfflat (column);
SELECT * FROM table ORDER BY column <-> '[1,2,3]' LIMIT 5;
```
Supports L2 distance, inner product, and cosine distance
[![Build Status](https://github.com/ankane/pgvector/workflows/build/badge.svg?branch=master)](https://github.com/ankane/pgvector/actions)
## Installation
Compile and install the extension (supports Postgres 9.6+)
```sh
git clone https://github.com/ankane/pgvector.git
cd pgvector
make
make install # may need sudo
```
Then load it in databases where you want to use it
```sql
CREATE EXTENSION vector;
```
## Getting Started
Create a vector column with 3 dimensions (replace `table` and `column` with non-reserved names)
```sql
CREATE TABLE table (column vector(3));
```
Insert values
```sql
INSERT INTO table VALUES ('[1,2,3]'), ('[4,5,6]');
```
Get the nearest neighbor by L2 distance
```sql
SELECT * FROM table ORDER BY column <-> '[3,1,2]' LIMIT 1;
```
Also supports inner product (`<#>`) and cosine distance (`<=>`)
Note: `<#>` returns the negative inner product since Postgres only supports `ASC` order index scans on operators
## Indexing
Speed up queries with an approximate index. Add an index for each distance function you want to use.
L2 distance
```sql
CREATE INDEX ON table USING ivfflat (column);
```
Inner product
```sql
CREATE INDEX ON table USING ivfflat (column vector_ip_ops);
```
Cosine distance
```sql
CREATE INDEX ON table USING ivfflat (column vector_cosine_ops);
```
Indexes should be created after the table has data for optimal clustering. Also note that, unlike typical indexes, which only affect performance, an approximate index can change query results: the same query may return different rows after the index is added.
### Index Options
Specify the number of inverted lists (100 by default)
```sql
CREATE INDEX ON table USING ivfflat (column) WITH (lists = 100);
```
### Query Options
Specify the number of probes (1 by default)
```sql
SET ivfflat.probes = 1;
```
A higher value improves recall at the cost of speed.
Use `SET LOCAL` inside a transaction to set it for a single query
```sql
BEGIN;
SET LOCAL ivfflat.probes = 1;
SELECT ...
COMMIT;
```
## Reference
### Vector Type
Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a float, and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 1024 dimensions.
### Vector Operators
Operator | Description
--- | ---
\+ | element-wise addition
\- | element-wise subtraction
<-> | Euclidean distance
<#> | negative inner product
<=> | cosine distance
### Vector Functions
Function | Description
--- | ---
cosine_distance(vector, vector) | cosine distance
inner_product(vector, vector) | inner product
l2_distance(vector, vector) | Euclidean distance
vector_dims(vector) | number of dimensions
vector_norm(vector) | Euclidean norm
## Thanks
Thanks to:
- [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131)
- [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss)
- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf)
- [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf)
- [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf)
## History
View the [changelog](https://github.com/ankane/pgvector/blob/master/CHANGELOG.md)
## Contributing
Everyone is encouraged to help improve this project. Here are a few ways you can help:
- [Report bugs](https://github.com/ankane/pgvector/issues)
- Fix bugs and [submit pull requests](https://github.com/ankane/pgvector/pulls)
- Write, clarify, or fix documentation
- Suggest or add new features
To get started with development:
```sh
git clone https://github.com/ankane/pgvector.git
cd pgvector
make
make install
```
To run all tests:
```sh
make installcheck # regression tests
make prove_installcheck # TAP tests
```
To run single tests:
```sh
make installcheck REGRESS=vector # regression test
make prove_installcheck PROVE_TESTS=t/001_wal.pl # TAP test
```
Directories
- `expected` - expected output for regression tests
- `sql` - regression tests
- `t` - TAP tests
Resources for contributors
- [Extension Building Infrastructure](https://www.postgresql.org/docs/current/extend-pgxs.html)
- [Index Access Method Interface Definition](https://www.postgresql.org/docs/current/indexam.html)
- [Generic WAL Records](https://www.postgresql.org/docs/13/generic-wal.html)

19
expected/btree.out Normal file
View File

@@ -0,0 +1,19 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t (val);
SELECT * FROM t WHERE val = '[1,2,3]';
val
---------
[1,2,3]
(1 row)
SELECT * FROM t ORDER BY val LIMIT 1;
val
---------
[0,0,0]
(1 row)
DROP TABLE t;

30
expected/cast.out Normal file
View File

@@ -0,0 +1,30 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT ARRAY[1,2,3]::vector;
array
---------
[1,2,3]
(1 row)
SELECT ARRAY[1,2,3]::float4[]::vector;
array
---------
[1,2,3]
(1 row)
SELECT ARRAY[1,2,3]::float8[]::vector;
array
---------
[1,2,3]
(1 row)
SELECT '{NULL}'::real[]::vector;
ERROR: array must not containing NULLs
SELECT '{NaN}'::real[]::vector;
ERROR: NaN not allowed in vector
SELECT '{Infinity}'::real[]::vector;
ERROR: infinite value not allowed in vector
SELECT '{-Infinity}'::real[]::vector;
ERROR: infinite value not allowed in vector
SELECT '{}'::real[]::vector;
ERROR: vector must have at least 1 dimension

56
expected/functions.out Normal file
View File

@@ -0,0 +1,56 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT '[1,2,3]'::vector + '[4,5,6]';
?column?
----------
[5,7,9]
(1 row)
SELECT '[1,2,3]'::vector - '[4,5,6]';
?column?
------------
[-3,-3,-3]
(1 row)
SELECT vector_dims('[1,2,3]');
vector_dims
-------------
3
(1 row)
SELECT round(vector_norm('[1,1]')::numeric, 5);
round
---------
1.41421
(1 row)
SELECT round(l2_distance('[1,2]', '[0,0]')::numeric, 5);
round
---------
2.23607
(1 row)
SELECT l2_distance('[1,2]', '[3]');
ERROR: different vector dimensions 2 and 1
SELECT inner_product('[1,2]', '[3,4]');
inner_product
---------------
11
(1 row)
SELECT inner_product('[1,2]', '[3]');
ERROR: different vector dimensions 2 and 1
SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5);
round
---------
0.00000
(1 row)
SELECT cosine_distance('[1,2]', '[0,0]');
cosine_distance
-----------------
NaN
(1 row)
SELECT cosine_distance('[1,2]', '[3]');
ERROR: different vector dimensions 2 and 1

View File

@@ -0,0 +1,21 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
val
---------
[1,1,1]
[1,2,3]
[1,2,4]
(3 rows)
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
val
-----
(0 rows)
DROP TABLE t;

22
expected/ivfflat_ip.out Normal file
View File

@@ -0,0 +1,22 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
val
---------
[1,2,4]
[1,2,3]
[1,1,1]
[0,0,0]
(4 rows)
SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
val
-----
(0 rows)
DROP TABLE t;

22
expected/ivfflat_l2.out Normal file
View File

@@ -0,0 +1,22 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,2,4]
[1,1,1]
[0,0,0]
(4 rows)
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
val
-----
(0 rows)
DROP TABLE t;

View File

@@ -0,0 +1,15 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
val
---------
[1,2,3]
[1,1,1]
[0,0,0]
(3 rows)
DROP TABLE t;

55
expected/vector.out Normal file
View File

@@ -0,0 +1,55 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT '[1,2,3]'::vector;
vector
---------
[1,2,3]
(1 row)
SELECT '[-1,2,3]'::vector;
vector
----------
[-1,2,3]
(1 row)
SELECT '[hello,1]'::vector;
ERROR: invalid input syntax for type vector: "hello"
LINE 1: SELECT '[hello,1]'::vector;
^
SELECT '[NaN,1]'::vector;
ERROR: NaN not allowed in vector
LINE 1: SELECT '[NaN,1]'::vector;
^
SELECT '[Infinity,1]'::vector;
ERROR: infinite value not allowed in vector
LINE 1: SELECT '[Infinity,1]'::vector;
^
SELECT '[-Infinity,1]'::vector;
ERROR: infinite value not allowed in vector
LINE 1: SELECT '[-Infinity,1]'::vector;
^
SELECT '[1,2,3'::vector;
ERROR: malformed vector literal
LINE 1: SELECT '[1,2,3'::vector;
^
DETAIL: Unexpected end of input.
SELECT '[1,2,3]9'::vector;
ERROR: malformed vector literal
LINE 1: SELECT '[1,2,3]9'::vector;
^
DETAIL: Junk after closing right brace.
SELECT '1,2,3'::vector;
ERROR: malformed vector literal: "1,2,3"
LINE 1: SELECT '1,2,3'::vector;
^
DETAIL: Vector contents must start with "[".
SELECT '[]'::vector;
ERROR: vector must have at least 1 dimension
LINE 1: SELECT '[]'::vector;
^
SELECT '[1,]'::vector;
ERROR: invalid input syntax for type vector: "]"
LINE 1: SELECT '[1,]'::vector;
^
SELECT '[1,2,3]'::vector(2);
ERROR: expected 2 dimensions, not 3

503
ivfbuild.c Normal file
View File

@@ -0,0 +1,503 @@
#include "postgres.h"
#include <float.h>
#include "catalog/index.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "storage/bufmgr.h"
#if PG_VERSION_NUM >= 120000
#include "access/tableam.h"
#endif
#if PG_VERSION_NUM >= 110000
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#else
#include "catalog/pg_operator.h"
#include "catalog/pg_type.h"
#endif
/*
 * Second parameter of the heap-scan callbacks in this file: an ItemPointer
 * named "tid" on PostgreSQL 13+, a HeapTuple named "hup" on older versions.
 */
#if PG_VERSION_NUM >= 130000
#define CALLBACK_ITEM_POINTER ItemPointer tid
#else
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
/*
 * Heap-scan callback that collects a sample of vectors.
 *
 * Null values, and values rejected by IvfflatNormValue, are ignored. The
 * sample array is filled front to back until it holds maxlen vectors; after
 * that, a reservoir skip count decides which later rows overwrite a randomly
 * chosen slot.
 *
 * NOTE(review): the row count passed to reservoir_get_next_S is the sample
 * length, which equals maxlen once the array is full -- confirm this is the
 * intended argument rather than the number of rows seen so far.
 */
static void
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
			   bool *isnull, bool tupleIsAlive, void *state)
{
	IvfflatBuildState *build = (IvfflatBuildState *) state;
	VectorArray reservoir = build->samples;
	int			capacity = reservoir->maxlen;
	Datum		datum;

	if (isnull[0])
		return;

	datum = values[0];

	/* Normalize only when a norm support function is available */
	if (build->normprocinfo != NULL &&
		!IvfflatNormValue(build->normprocinfo, build->collation, &datum, build->normvec))
		return;

	/* Still filling the reservoir: append and finish */
	if (reservoir->length < capacity)
	{
		VectorArraySet(reservoir, reservoir->length, DatumGetVector(datum));
		reservoir->length++;
		return;
	}

	/* Reservoir is full: draw a new skip count once the old one is used up */
	if (build->rowstoskip < 0)
		build->rowstoskip = reservoir_get_next_S(&build->rstate, reservoir->length, capacity);

	if (build->rowstoskip <= 0)
	{
		/* Replace a randomly chosen existing sample */
		int			victim = (int) (capacity * sampler_random_fract(build->rstate.randstate));

		Assert(victim >= 0 && victim < capacity);
		VectorArraySet(reservoir, victim, DatumGetVector(datum));
	}

	build->rowstoskip -= 1;
}
/*
 * Fill buildstate->samples by visiting a random selection of heap blocks and
 * feeding every indexable row of each block to SampleCallback.
 */
static void
SampleRows(IvfflatBuildState * buildstate)
{
	Relation	heapRel = buildstate->heap;
	Relation	indexRel = buildstate->index;
	IndexInfo  *info = buildstate->indexInfo;
	int			wanted = buildstate->samples->maxlen;
	BlockNumber nblocks = RelationGetNumberOfBlocks(heapRel);

	/* A negative skip count makes SampleCallback draw a fresh one */
	buildstate->rowstoskip = -1;

	BlockSampler_Init(&buildstate->bs, nblocks, wanted, random());
	reservoir_init_selection_state(&buildstate->rstate, wanted);

	for (;;)
	{
		BlockNumber blkno;

		if (!BlockSampler_HasMore(&buildstate->bs))
			break;

		blkno = BlockSampler_Next(&buildstate->bs);

		/* Scan exactly one block; the scan API differs across versions */
#if PG_VERSION_NUM >= 120000
		table_index_build_range_scan(heapRel, indexRel, info,
									 false, true, true, blkno, 1, SampleCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
		IndexBuildHeapRangeScan(heapRel, indexRel, info,
								true, true, blkno, 1, SampleCallback, (void *) buildstate, NULL);
#else
		IndexBuildHeapRangeScan(heapRel, indexRel, info,
								true, true, blkno, 1, SampleCallback, (void *) buildstate);
#endif
	}
}
/*
 * Heap-scan callback that assigns one row to its nearest list.
 *
 * Null values, and values rejected by IvfflatNormValue, are skipped. For every
 * other row a (list, blkno, offset, vector) tuple is handed to the tuplesort.
 */
static void
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
			  bool *isnull, bool tupleIsAlive, void *state)
{
	IvfflatBuildState *build = (IvfflatBuildState *) state;
	VectorArray centroids = build->centers;
	TupleTableSlot *sortSlot = build->slot;
	Datum		vec;
	double		best = DBL_MAX;
	int			bestList = -1;
	int			c;

#if PG_VERSION_NUM < 130000
	ItemPointer tid = &hup->t_self;
#endif

	if (isnull[0])
		return;

	vec = values[0];

	/* Normalize only when a norm support function is available */
	if (build->normprocinfo != NULL &&
		!IvfflatNormValue(build->normprocinfo, build->collation, &vec, build->normvec))
		return;

	/* Linear scan over the centers; the earliest center wins ties */
	for (c = 0; c < centroids->length; c++)
	{
		Datum		center = PointerGetDatum(VectorArrayGet(centroids, c));
		double		d = DatumGetFloat8(FunctionCall2Coll(build->procinfo, build->collation, vec, center));

		if (d < best)
		{
			best = d;
			bestList = c;
		}
	}

	/* Build the virtual tuple: list number, heap block, heap offset, vector */
	ExecClearTuple(sortSlot);

	sortSlot->tts_values[0] = Int32GetDatum(bestList);
	sortSlot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid));
	sortSlot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid));
	sortSlot->tts_values[3] = vec;

	sortSlot->tts_isnull[0] = false;
	sortSlot->tts_isnull[1] = false;
	sortSlot->tts_isnull[2] = false;
	sortSlot->tts_isnull[3] = false;

	ExecStoreVirtualTuple(sortSlot);

	/*
	 * Hand the tuple to the sort.
	 *
	 * tuplesort_puttupleslot comment: Input data is always copied; the caller
	 * need not save it.
	 */
	tuplesort_puttupleslot(build->sortstate, sortSlot);
}
/*
 * Fetch the next sorted tuple and convert it to an index tuple.
 *
 * On success *list receives the list number and *itup a newly formed index
 * tuple whose t_tid is the heap location stored in the sort tuple. When the
 * sort is exhausted *list is set to -1 and *itup is left untouched.
 */
static inline void
GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list)
{
	bool		found;
	bool		unusedNull;
	bool		vectorIsNull;
	int			heapBlock;
	int			heapOffset;
	Datum		vector;

#if PG_VERSION_NUM >= 100000
	found = tuplesort_gettupleslot(sortstate, true, false, slot, NULL);
#else
	found = tuplesort_gettupleslot(sortstate, true, slot, NULL);
#endif

	if (!found)
	{
		*list = -1;
		return;
	}

	/* Attributes follow the sort tuple descriptor: list, blkno, offset, vector */
	*list = DatumGetInt32(slot_getattr(slot, 1, &unusedNull));
	heapBlock = DatumGetInt32(slot_getattr(slot, 2, &unusedNull));
	heapOffset = DatumGetInt32(slot_getattr(slot, 3, &unusedNull));
	vector = slot_getattr(slot, 4, &vectorIsNull);

	/* Form the index tuple and point it at the heap row */
	*itup = index_form_tuple(tupdesc, &vector, &vectorIsNull);
	ItemPointerSet(&(*itup)->t_tid, heapBlock, heapOffset);
}
/*
 * Create initial entry pages
 *
 * Walks the sorted tuples once. For each list, in list order, a fresh page
 * chain is started, all tuples belonging to that list are appended to it, and
 * the list entry is then updated with the chain's start and insert pages.
 * buildstate->indtuples is incremented for every tuple written.
 */
static void
InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
{
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	int			list;
	IndexTuple	itup = NULL;	/* silence compiler warning */
	BlockNumber startPage = InvalidBlockNumber;
	BlockNumber insertPage = InvalidBlockNumber;
	Size		itemsz;
	int			i;
#if PG_VERSION_NUM >= 120000
	TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple);
#else
	TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
#endif
	TupleDesc	tupdesc = RelationGetDescr(index);

	/* Prime the loop with the first sorted tuple (list is -1 if there is none) */
	GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);

	for (i = 0; i < buildstate->centers->length; i++)
	{
		buf = IvfflatNewBuffer(index, forkNum);
		IvfflatInitPage(index, &buf, &page, &state);

		startPage = BufferGetBlockNumber(buf);

		/* Get all tuples for list */
		while (list == i)
		{
			/* Check for free space */
			itemsz = MAXALIGN(IndexTupleSize(itup));
			if (PageGetFreeSpace(page) < itemsz)
				IvfflatAppendPage(index, &buf, &page, &state, forkNum);

			/* Add the item */
			if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
				elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

			buildstate->indtuples += 1;

			/*
			 * The tuple has been copied into the page, so release our copy
			 * now; otherwise every index tuple stays allocated until the
			 * build's memory context goes away.
			 */
			pfree(itup);
			itup = NULL;

			GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);
		}

		insertPage = BufferGetBlockNumber(buf);

		IvfflatCommitBuffer(buf, state);

		/* Set the start and insert pages */
		IvfflatUpdateList(index, state, buildstate->listInfo[i], insertPage, startPage, forkNum);
	}
}
/*
 * Initialize the build state
 *
 * Records the relations, reads the "lists" option and the column's declared
 * dimensions, looks up the support functions, and allocates everything the
 * later build phases need (sort tuple descriptor and slot, centers, list
 * locations, normalization scratch vector).
 *
 * Raises an error when the indexed column has no dimensions (typmod < 0).
 */
static void
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
{
	Oid			vectorTypeId;

	buildstate->heap = heap;
	buildstate->index = index;
	buildstate->indexInfo = indexInfo;

	buildstate->lists = IvfflatGetLists(index);
	buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;

	/* Require column to have dimensions to be indexed */
	if (buildstate->dimensions < 0)
		elog(ERROR, "column does not have dimensions");

	buildstate->reltuples = 0;
	buildstate->indtuples = 0;

	/* Get support functions */
	buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
	buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
	buildstate->collation = index->rd_indcollation[0];

	/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
	buildstate->tupdesc = CreateTemplateTupleDesc(4);
#else
	buildstate->tupdesc = CreateTemplateTupleDesc(4, false);
#endif
	TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0);
	TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
	TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);

	/*
	 * TupleDescAttr is already used unconditionally above, so use it here as
	 * well instead of a version-specific attrs[0] access.
	 */
	vectorTypeId = TupleDescAttr(RelationGetDescr(index), 0)->atttypid;
	TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", vectorTypeId, -1, 0);

#if PG_VERSION_NUM >= 120000
	buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
#else
	buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
#endif

	buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions);
	buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);

	/* Reuse for each tuple */
	buildstate->normvec = InitVector(buildstate->dimensions);
}
/*
 * Release the allocations made by InitBuildState that are freed explicitly:
 * the normalization scratch vector, the list locations, and the centers.
 */
static void
FreeBuildState(IvfflatBuildState * buildstate)
{
	pfree(buildstate->normvec);
	pfree(buildstate->listInfo);
	pfree(buildstate->centers);
}
/*
 * Sample the heap and run k-means to fill buildstate->centers.
 *
 * The sample array is sized at 50 vectors per list, but never fewer than
 * 10000. When there is no heap (buildstate->heap is NULL) sampling is skipped
 * and k-means runs on the empty sample array.
 */
static void
ComputeCenters(IvfflatBuildState * buildstate)
{
	/* The number of samples has a large effect on index build time */
	int			wanted = buildstate->lists * 50;

	if (wanted < 10000)
		wanted = 10000;

	/* Collect the samples */
	buildstate->samples = VectorArrayInit(wanted, buildstate->dimensions);
	if (buildstate->heap != NULL)
		SampleRows(buildstate);

	/* Derive the centers from the samples */
	IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers);

	/* Release the samples before later phases allocate more memory */
	pfree(buildstate->samples);
}
/*
 * Allocate a new page and write the index metadata (magic number, format
 * version, dimensions, list count) into it.
 */
static void
CreateMetaPage(Relation index, int dimensions, int lists, ForkNumber forkNum)
{
	Buffer		metaBuf;
	Page		metaPage;
	GenericXLogState *xlog;
	IvfflatMetaPage meta;

	metaBuf = IvfflatNewBuffer(index, forkNum);
	IvfflatInitPage(index, &metaBuf, &metaPage, &xlog);

	/* Fill in the metadata */
	meta = IvfflatPageGetMeta(metaPage);
	meta->magicNumber = IVFFLAT_MAGIC_NUMBER;
	meta->version = IVFFLAT_VERSION;
	meta->dimensions = dimensions;
	meta->lists = lists;

	/* pd_lower is set to the first byte after the metadata struct */
	((PageHeader) metaPage)->pd_lower = (char *) (meta + 1) - (char *) metaPage;

	IvfflatCommitBuffer(metaBuf, xlog);
}
/*
 * Create list pages
 *
 * Writes one list entry per center, appending pages as they fill up, and
 * records each entry's block and offset in (*listInfo)[i] so the entry can be
 * found again once its tuples have been written. Every entry starts with
 * invalid start and insert pages.
 */
static void
CreateListPages(Relation index, VectorArray centers, int dimensions,
				int lists, ForkNumber forkNum, ListInfo * *listInfo)
{
	int			i;
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	OffsetNumber offno;
	Size		itemsz;
	IvfflatList list;

	itemsz = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions));

	/*
	 * Zero the buffer: itemsz is rounded up by MAXALIGN, and the whole item
	 * (padding included) is copied into the page, so the padding bytes must
	 * not be left uninitialized.
	 */
	list = palloc0(itemsz);

	buf = IvfflatNewBuffer(index, forkNum);
	IvfflatInitPage(index, &buf, &page, &state);

	for (i = 0; i < lists; i++)
	{
		/* Load list */
		list->startPage = InvalidBlockNumber;
		list->insertPage = InvalidBlockNumber;
		memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions));

		/* Ensure free space */
		if (PageGetFreeSpace(page) < itemsz)
			IvfflatAppendPage(index, &buf, &page, &state, forkNum);

		/* Add the item */
		offno = PageAddItem(page, (Item) list, itemsz, InvalidOffsetNumber, false, false);
		if (offno == InvalidOffsetNumber)
			elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));

		/* Save location info */
		(*listInfo)[i].blkno = BufferGetBlockNumber(buf);
		(*listInfo)[i].offno = offno;
	}

	IvfflatCommitBuffer(buf, state);

	pfree(list);
}
/*
 * Create entry pages
 *
 * Scans the heap (when there is one), assigning every row to its nearest
 * list via BuildCallback, sorts the resulting tuples by list number, and
 * writes them out list by list. buildstate->reltuples receives the heap
 * scan's tuple count.
 */
static void
CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum)
{
	AttrNumber	attNums[] = {1};

	/*
	 * The sort key is attribute 1 ("list"), which the sort tuple descriptor
	 * declares as INT4OID, so the ordering operator must be the int4 one.
	 */
	Oid			sortOperators[] = {Int4LessOperator};
	Oid			sortCollations[] = {InvalidOid};
	bool		nullsFirstFlags[] = {false};

#if PG_VERSION_NUM >= 110000
	buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, NULL, false);
#else
	buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, false);
#endif

	/* Add tuples to sort */
	if (buildstate->heap != NULL)
	{
#if PG_VERSION_NUM >= 120000
		buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
													   true, true, BuildCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
		buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
												   true, BuildCallback, (void *) buildstate, NULL);
#else
		buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
												   true, BuildCallback, (void *) buildstate);
#endif
	}

	/* Sort and insert */
	tuplesort_performsort(buildstate->sortstate);
	InsertTuples(buildstate->index, buildstate, forkNum);
	tuplesort_end(buildstate->sortstate);
}
/*
 * Run the full build into the given fork: set up state, compute the centers,
 * then write the metapage, the list pages and the entry pages, in that order.
 */
static void
BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
		   IvfflatBuildState * buildstate, ForkNumber forkNum)
{
	int			dims;
	int			nlists;

	InitBuildState(buildstate, heap, index, indexInfo);
	dims = buildstate->dimensions;
	nlists = buildstate->lists;

	ComputeCenters(buildstate);

	/* Pages are created in this order: meta, lists, entries */
	CreateMetaPage(index, dims, nlists, forkNum);
	CreateListPages(index, buildstate->centers, dims, nlists, forkNum, &buildstate->listInfo);
	CreateEntryPages(buildstate, forkNum);

	FreeBuildState(buildstate);
}
/*
 * Build the index in the main fork and report how many heap tuples were
 * scanned and how many index tuples were written.
 */
IndexBuildResult *
ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo)
{
	IvfflatBuildState state;
	IndexBuildResult *stats;

	BuildIndex(heap, index, indexInfo, &state, MAIN_FORKNUM);

	stats = (IndexBuildResult *) palloc(sizeof(*stats));
	stats->heap_tuples = state.reltuples;
	stats->index_tuples = state.indtuples;

	return stats;
}
/*
 * Build an empty index in the init fork (used for unlogged tables). No heap
 * is passed, so no rows are sampled or scanned.
 */
void
ivfflatbuildempty(Relation index)
{
	IvfflatBuildState state;

	BuildIndex(NULL, index, BuildIndexInfo(index), &state, INIT_FORKNUM);
}

168
ivfflat.c Normal file
View File

@@ -0,0 +1,168 @@
#include "postgres.h"
#include "access/amapi.h"
#include "commands/vacuum.h"
#include "ivfflat.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"
/* Reloption kind for ivfflat indexes; assigned in _PG_init */
static relopt_kind ivfflat_relopt_kind;

/*
 * Register the "lists" index option and the "ivfflat.probes" setting.
 */
void
_PG_init(void)
{
	ivfflat_relopt_kind = add_reloption_kind();

	/* PostgreSQL 13+ takes an additional lock-mode argument */
#if PG_VERSION_NUM >= 130000
	add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists",
					  IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS, AccessExclusiveLock);
#else
	add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists",
					  IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS);
#endif

	/* Default 1, allowed range 1..IVFFLAT_MAX_LISTS, settable by any user */
	DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes",
							"Valid range is 1..lists.", &ivfflat_probes,
							1, 1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL);
}
/*
 * Estimate the cost of an index scan by delegating to genericcostestimate
 * and copying its results to the output parameters.
 *
 * TODO Improve cost estimation
 */
static void
ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
					Cost *indexStartupCost, Cost *indexTotalCost,
					Selectivity *indexSelectivity, double *indexCorrelation
#if PG_VERSION_NUM >= 100000
					,double *indexPages
#endif
)
{
	GenericCosts generic;
#if PG_VERSION_NUM < 120000
	List	   *quals = deconstruct_indexquals(path);
#endif

	/* genericcostestimate expects a zeroed struct */
	MemSet(&generic, 0, sizeof(generic));

#if PG_VERSION_NUM >= 120000
	genericcostestimate(root, path, loop_count, &generic);
#else
	genericcostestimate(root, path, loop_count, quals, &generic);
#endif

#if PG_VERSION_NUM >= 100000
	*indexPages = generic.numIndexPages;
#endif
	*indexCorrelation = generic.indexCorrelation;
	*indexSelectivity = generic.indexSelectivity;
	*indexTotalCost = generic.indexTotalCost;
	*indexStartupCost = generic.indexStartupCost;
}
/*
 * Parse the reloptions into an IvfflatOptions struct, validating them when
 * requested. The only recognized option is "lists".
 */
static bytea *
ivfflatoptions(Datum reloptions, bool validate)
{
	static const relopt_parse_elt optionTable[] = {
		{"lists", RELOPT_TYPE_INT, offsetof(IvfflatOptions, lists)},
	};

#if PG_VERSION_NUM >= 130000
	/* PostgreSQL 13+ offers a single helper for parse + allocate + fill */
	return (bytea *) build_reloptions(reloptions, validate,
									  ivfflat_relopt_kind,
									  sizeof(IvfflatOptions),
									  optionTable, lengthof(optionTable));
#else
	int			nparsed;
	relopt_value *parsed;
	IvfflatOptions *result;

	parsed = parseRelOptions(reloptions, validate, ivfflat_relopt_kind, &nparsed);
	result = allocateReloptStruct(sizeof(IvfflatOptions), parsed, nparsed);
	fillRelOptions((void *) result, sizeof(IvfflatOptions), parsed, nparsed,
				   validate, optionTable, lengthof(optionTable));

	return (bytea *) result;
#endif
}
/*
 * Validate catalog entries for the specified operator class
 *
 * NOTE(review): no checks are performed; every operator class is reported as
 * valid and opclassoid is not examined.
 */
static bool
ivfflatvalidate(Oid opclassoid)
{
return true;
}
/*
 * Access method handler: returns an IndexAmRoutine describing what ivfflat
 * supports and which functions implement it. Fields that only exist in newer
 * PostgreSQL versions are set under version guards.
 */
PG_FUNCTION_INFO_V1(ivfflathandler);
Datum
ivfflathandler(PG_FUNCTION_ARGS)
{
	IndexAmRoutine *routine = makeNode(IndexAmRoutine);

	/* No strategies; four support functions (numbered in ivfflat.h) */
	routine->amstrategies = 0;
	routine->amsupport = 4;
	routine->amkeytype = InvalidOid;

	/* Capabilities: ordering by operator only, single column, key optional */
	routine->amcanorder = false;
	routine->amcanorderbyop = true;
	routine->amcanbackward = false; /* backward scans are not offered */
	routine->amcanunique = false;
	routine->amcanmulticol = false;
	routine->amoptionalkey = true;
	routine->amsearcharray = false;
	routine->amsearchnulls = false;
	routine->amstorage = false;
	routine->amclusterable = false;
	routine->ampredlocks = false;

	/* Build, maintenance and planner callbacks */
	routine->ambuild = ivfflatbuild;
	routine->ambuildempty = ivfflatbuildempty;
	routine->aminsert = ivfflatinsert;
	routine->ambulkdelete = ivfflatbulkdelete;
	routine->amvacuumcleanup = ivfflatvacuumcleanup;
	routine->amcanreturn = NULL;
	routine->amcostestimate = ivfflatcostestimate;
	routine->amoptions = ivfflatoptions;
	routine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
	routine->amvalidate = ivfflatvalidate;

	/* Scan callbacks; bitmap scans and mark/restore are not provided */
	routine->ambeginscan = ivfflatbeginscan;
	routine->amrescan = ivfflatrescan;
	routine->amgettuple = ivfflatgettuple;
	routine->amgetbitmap = NULL;
	routine->amendscan = ivfflatendscan;
	routine->ammarkpos = NULL;
	routine->amrestrpos = NULL;

#if PG_VERSION_NUM >= 100000
	/* Parallel scans are not supported */
	routine->amcanparallel = false;
	routine->amestimateparallelscan = NULL;
	routine->aminitparallelscan = NULL;
	routine->amparallelrescan = NULL;
#endif

#if PG_VERSION_NUM >= 110000
	routine->amcaninclude = false;
#endif

#if PG_VERSION_NUM >= 120000
	routine->ambuildphasename = NULL;
#endif

#if PG_VERSION_NUM >= 130000
	routine->amoptsprocnum = 0;
	routine->amusemaintenanceworkmem = false;	/* not used during VACUUM */
	routine->amparallelvacuumoptions = VACUUM_OPTION_NO_PARALLEL;	/* TODO support parallel */
#endif

	PG_RETURN_POINTER(routine);
}

193
ivfflat.h Normal file
View File

@@ -0,0 +1,193 @@
#ifndef IVFFLAT_H
#define IVFFLAT_H
#include "postgres.h"
#include "access/generic_xlog.h"
#include "access/reloptions.h"
#include "nodes/execnodes.h"
#include "utils/sampling.h"
#include "utils/tuplesort.h"
#include "vector.h"
/* Support functions */
#define IVFFLAT_DISTANCE_PROC 1
#define IVFFLAT_NORM_PROC 2
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
#define IVFFLAT_KMEANS_NORM_PROC 4
#define IVFFLAT_VERSION 1
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
#define IVFFLAT_PAGE_ID 0xFF84
/* Preserved page numbers */
#define IVFFLAT_METAPAGE_BLKNO 0
#define IVFFLAT_HEAD_BLKNO 1 /* first list page */
#define IVFFLAT_DEFAULT_LISTS 100
#define IVFFLAT_MAX_LISTS 32768
#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
#define IvfflatPageGetOpaque(page) ((IvfflatPageOpaque) PageGetSpecialPointer(page))
#define IvfflatPageGetMeta(page) ((IvfflatMetaPageData *) PageGetContents(page))
#if PG_VERSION_NUM < 100000
#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber
#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber
#endif
/*
 * Variables
 *
 * ivfflat_probes is copied into the scan state by ivfflatrescan and caps the
 * number of lists searched per scan.
 *
 * NOTE(review): this is a (tentative) definition rather than an "extern"
 * declaration, so every translation unit that includes this header defines
 * the variable. That fails to link when common symbols are disabled
 * (-fno-common). Consider "extern int ivfflat_probes;" here with a single
 * definition in one .c file.
 */
int ivfflat_probes;
/*
 * Array of up to maxlen vectors, all with the same number of dimensions,
 * stored back to back after the header.
 *
 * Each element occupies VECTOR_SIZE(dim) bytes, so items[] must not be
 * indexed directly; use VectorArrayGet / VectorArraySet, which compute the
 * byte offset from dim.
 */
typedef struct VectorArrayData
{
int length; /* number of vectors currently stored */
int maxlen; /* capacity, fixed by VectorArrayInit */
int dim; /* dimensions of every stored vector */
Vector items[FLEXIBLE_ARRAY_MEMBER]; /* packed vector storage */
} VectorArrayData;
typedef VectorArrayData * VectorArray;
/*
 * Location of a list entry (IvfflatListData) inside the chain of list pages:
 * the list page's block number and the item's offset on that page.
 */
typedef struct ListInfo
{
BlockNumber blkno; /* list page holding the entry */
OffsetNumber offno; /* item offset of the entry on that page */
} ListInfo;
/*
 * IVFFlat index options
 *
 * Read from index->rd_options by IvfflatGetLists.
 */
typedef struct IvfflatOptions
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int lists; /* number of lists */
} IvfflatOptions;
/*
 * Working state for an index build.
 *
 * NOTE(review): the build code that fills and consumes this struct is not in
 * this part of the file; the field groups below carry the original section
 * labels only.
 */
typedef struct IvfflatBuildState
{
/* Info */
Relation heap;
Relation index;
IndexInfo *indexInfo;
/* Settings */
int dimensions;
int lists;
/* Statistics */
double indtuples;
double reltuples;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
/* Variables */
VectorArray samples;
VectorArray centers;
ListInfo *listInfo;
Vector *normvec;
/* Sampling */
BlockSamplerData bs;
ReservoirStateData rstate;
int rowstoskip;
/* Sorting */
Tuplesortstate *sortstate;
TupleDesc tupdesc;
TupleTableSlot *slot;
} IvfflatBuildState;
/*
 * Contents of the metapage, accessed through IvfflatPageGetMeta.
 *
 * NOTE(review): presumably lives on block IVFFLAT_METAPAGE_BLKNO with
 * magicNumber/version set from IVFFLAT_MAGIC_NUMBER/IVFFLAT_VERSION; the
 * code that writes the metapage is not in this part of the file.
 */
typedef struct IvfflatMetaPageData
{
uint32 magicNumber;
uint32 version;
uint16 dimensions;
uint16 lists;
} IvfflatMetaPageData;
typedef IvfflatMetaPageData * IvfflatMetaPage;
/*
 * Special-space data at the end of each page, accessed through
 * IvfflatPageGetOpaque. Pages form singly linked chains via nextblkno.
 */
typedef struct IvfflatPageOpaqueData
{
BlockNumber nextblkno; /* next page in the chain; InvalidBlockNumber at the end */
uint16 unused; /* not assigned by IvfflatInitPage */
uint16 page_id; /* for identification of IVFFlat indexes */
} IvfflatPageOpaqueData;
typedef IvfflatPageOpaqueData * IvfflatPageOpaque;
/*
 * One list (cluster) as stored on a list page.
 *
 * center is variable-length, so it is the last member; the on-page size is
 * IVFFLAT_LIST_SIZE(dim).
 */
typedef struct IvfflatListData
{
BlockNumber startPage; /* first entry page of the list; scans start here */
BlockNumber insertPage; /* entry page where inserts start looking for free space */
Vector center; /* list center compared against values by the distance proc */
} IvfflatListData;
typedef IvfflatListData * IvfflatList;
/*
 * Per-scan summary of one list: where its entries start and how far its
 * center is from the query value. Sorted by distance in GetScanLists.
 */
typedef struct IvfflatScanList
{
BlockNumber startPage; /* first entry page of the list */
double distance; /* distance between the list center and the query value */
} IvfflatScanList;
/*
 * Scan state kept in scan->opaque. Allocated by ivfflatbeginscan with room
 * for one lists[] entry per index list.
 */
typedef struct IvfflatScanOpaqueData
{
int probes; /* lists to search; set from ivfflat_probes in ivfflatrescan */
bool first; /* true until ivfflatgettuple has collected and sorted the items */
Buffer buf; /* pinned index page of the last returned item, or InvalidBuffer */
/* Sorting */
Tuplesortstate *sortstate;
TupleDesc tupdesc; /* (distance, blkno, offset, indexblkno) */
TupleTableSlot *slot; /* receives tuples read back from the sort */
bool isnull; /* scratch output for slot_getattr */
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
} IvfflatScanOpaqueData;
typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
/*
 * VectorArray helpers.
 *
 * VECTOR_ARRAY_SIZE   - bytes needed for an array of _length vectors of _dim dimensions
 * VECTOR_ARRAY_OFFSET - address of element _offset (elements are VECTOR_SIZE(dim) bytes apart)
 * VectorArrayGet      - element _offset as a Vector pointer
 * VectorArraySet      - copy the vector _val into element _offset
 *
 * Every argument is parenthesized so that expressions such as "n + 1" or
 * "arr + i" expand correctly. Note that _arr is evaluated more than once, so
 * it must not have side effects.
 */
#define VECTOR_ARRAY_SIZE(_length, _dim) (offsetof(VectorArrayData, items) + (_length) * VECTOR_SIZE(_dim))
#define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char*) (_arr) + offsetof(VectorArrayData, items) + (_offset) * VECTOR_SIZE((_arr)->dim))
#define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset))
#define VectorArraySet(_arr, _offset, _val) (memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), (_val), VECTOR_SIZE((_arr)->dim)))
/* Methods */
void _PG_init(void);
/* Vector array helpers (ivfutils.c) */
VectorArray VectorArrayInit(int maxlen, int dimensions);
void PrintVectorArray(char *msg, VectorArray arr);
/* Fills centers from samples (ivfkmeans.c) */
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
/* Returns NULL when the index has no support function with this number */
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
/* Returns false when the value should not be indexed (norm is not positive) */
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
int IvfflatGetLists(Relation index);
/* Page and buffer helpers (ivfutils.c) */
void IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum);
void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);
void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum);
Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
void IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
/* Index access methods */
IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo);
void ivfflatbuildempty(Relation index);
/* The trailing IndexInfo argument only exists from PostgreSQL 10 on */
bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 100000
,IndexInfo *indexInfo
#endif
);
IndexBulkDeleteResult *ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
IndexBulkDeleteResult *ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys);
void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir);
void ivfflatendscan(IndexScanDesc scan);
#endif

163
ivfinsert.c Normal file
View File

@@ -0,0 +1,163 @@
#include "postgres.h"
#include <float.h>
#include "ivfflat.h"
#include "storage/bufmgr.h"
/*
 * Find the list that minimizes the distance function
 *
 * Walks the chain of list pages starting at IVFFLAT_HEAD_BLKNO and compares
 * values[0] with every list center using the index's distance proc.
 *
 * On return, *insertPage holds the chosen list's insert page and *listInfo
 * the location of the list entry. If the index contains at least one list, a
 * list is always chosen: when no distance compares below DBL_MAX (NaN or
 * infinite distances), the first list is used as a fallback. Without this,
 * *insertPage would stay InvalidBlockNumber, which ReadBuffer treats as
 * P_NEW (a request to extend the relation).
 */
static void
FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * listInfo)
{
	double		minDistance = DBL_MAX;
	BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
	FmgrInfo   *procinfo = index_getprocinfo(rel, 1, IVFFLAT_DISTANCE_PROC);
	Oid			collation = rel->rd_indcollation[0];

	/* Do not rely on the caller having initialized the outputs */
	*insertPage = InvalidBlockNumber;
	listInfo->blkno = InvalidBlockNumber;
	listInfo->offno = InvalidOffsetNumber;

	/* Search all list pages */
	while (BlockNumberIsValid(nextblkno))
	{
		Buffer		cbuf;
		Page		cpage;
		OffsetNumber offno;
		OffsetNumber maxoffno;

		cbuf = ReadBuffer(rel, nextblkno);
		LockBuffer(cbuf, BUFFER_LOCK_SHARE);
		cpage = BufferGetPage(cbuf);
		maxoffno = PageGetMaxOffsetNumber(cpage);

		for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
		{
			IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
			double		distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center)));
			bool		closer = distance < minDistance;

			/* Take the closest list, or the first one seen as a fallback */
			if (closer || !BlockNumberIsValid(*insertPage))
			{
				*insertPage = list->insertPage;
				listInfo->blkno = nextblkno;
				listInfo->offno = offno;

				/*
				 * Only a real improvement lowers the bound, so a NaN or
				 * infinite fallback distance cannot block later, closer
				 * lists.
				 */
				if (closer)
					minDistance = distance;
			}
		}

		nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
		UnlockReleaseBuffer(cbuf);
	}
}
/*
 * Prepare to insert an index tuple
 *
 * Reads and exclusively locks block insertPage, starts a generic WAL record
 * and registers the buffer with it. The buffer, the registered page and the
 * WAL state are handed back through buf, page and state.
 */
static void
LoadInsertPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, BlockNumber insertPage)
{
	Buffer		target;
	GenericXLogState *xlog;

	target = ReadBuffer(index, insertPage);
	LockBuffer(target, BUFFER_LOCK_EXCLUSIVE);

	xlog = GenericXLogStart(index);

	*buf = target;
	*state = xlog;
	*page = GenericXLogRegisterBuffer(xlog, target, 0);
}
/*
 * Insert a tuple into the index
 *
 * Picks the closest list, then walks that list's entry pages from its insert
 * page until one has room for the tuple, appending a fresh page at the end
 * of the chain if none does. When a page was appended, the list entry is
 * updated to point at it as the new insert page.
 *
 * heapRel is not used here.
 */
static void
InsertTuple(Relation rel, IndexTuple itup, Relation heapRel, Datum *values)
{
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	ListInfo	listInfo;
	BlockNumber insertPage = InvalidBlockNumber;
	bool		newPage = false;
	Size		itemsz = MAXALIGN(IndexTupleSize(itup));

	/* Find the insert page - sets the page and list info */
	FindInsertPage(rel, values, &insertPage, &listInfo);
	Assert(BlockNumberIsValid(insertPage));
	Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)));

	LoadInsertPage(rel, &buf, &page, &state, insertPage);

	/* Find a page to insert the item */
	for (;;)
	{
		BlockNumber next;

		if (PageGetFreeSpace(page) >= itemsz)
			break;

		next = IvfflatPageGetOpaque(page)->nextblkno;

		if (!BlockNumberIsValid(next))
		{
			/* End of the chain: add a new page */
			IvfflatAppendPage(rel, &buf, &page, &state, MAIN_FORKNUM);
			insertPage = BufferGetBlockNumber(buf);
			newPage = true;
			continue;
		}

		/* Nothing was changed on this page: drop it and move to the next */
		GenericXLogAbort(state);
		UnlockReleaseBuffer(buf);
		insertPage = next;
		LoadInsertPage(rel, &buf, &page, &state, insertPage);
	}

	/* Add to next offset */
	if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
		elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(rel));

	IvfflatCommitBuffer(buf, state);

	/* Update the insert page */
	if (newPage)
		IvfflatUpdateList(rel, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
}
/*
 * Insert a tuple into the index
 *
 * NULL values are not indexed. When the index has a norm support function,
 * the value is normalized first, and values that cannot be normalized are
 * not indexed either. Always returns false.
 */
bool
ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
			  Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 100000
			  ,IndexInfo *indexInfo
#endif
)
{
	Datum		original;
	Datum		value;
	FmgrInfo   *normprocinfo;
	IndexTuple	itup;

	if (isnull[0])
		return false;

	original = values[0];
	value = original;

	/* Normalize if needed */
	normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
	if (normprocinfo != NULL &&
		!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL))
		return false;

	itup = index_form_tuple(RelationGetDescr(index), &value, isnull);
	itup->t_tid = *heap_tid;
	InsertTuple(index, itup, heap, &value);
	pfree(itup);

	/* Clean up if we allocated a new value */
	if (value != original)
		pfree(DatumGetPointer(value));

	return false;
}

479
ivfkmeans.c Normal file
View File

@@ -0,0 +1,479 @@
#include "postgres.h"
#include <float.h>
#include "ivfflat.h"
#include "miscadmin.h"
/*
 * Initialize with kmeans++
 *
 * https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
 *
 * Fills all centers->maxlen centers from samples. As a by-product,
 * lowerBound[j * numCenters + i] receives the distance from sample j to
 * center i for every sample/center pair, so lowerBound must have room for
 * samples->length * centers->maxlen doubles.
 */
static void
InitCenters(Relation index, VectorArray samples, VectorArray centers, double *lowerBound)
{
FmgrInfo *procinfo;
Oid collation;
int i;
int j;
double distance;
double sum;
double choice;
Vector *vec;
/* weight[j]: smallest squared distance from sample j to any center chosen so far */
double *weight = palloc(samples->length * sizeof(double));
int numCenters = centers->maxlen;
int numSamples = samples->length;
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
collation = index->rd_indcollation[0];
/* Choose an initial center uniformly at random */
VectorArraySet(centers, 0, VectorArrayGet(samples, random() % samples->length));
centers->length++;
for (j = 0; j < numSamples; j++)
weight[j] = DBL_MAX;
/* Iteration i measures every sample against center i, then picks center i + 1 */
for (i = 0; i < numCenters; i++)
{
CHECK_FOR_INTERRUPTS();
sum = 0.0;
for (j = 0; j < numSamples; j++)
{
vec = VectorArrayGet(samples, j);
/* Only need to compute distance for new center */
/* TODO Use triangle inequality to reduce distance calculations */
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
/* Set lower bound */
lowerBound[j * numCenters + i] = distance;
/* Use distance squared for weighted probability distribution */
distance *= distance;
if (distance < weight[j])
weight[j] = distance;
sum += weight[j];
}
/* Only compute lower bound on last iteration */
if (i + 1 == numCenters)
break;
/* Choose new center using weighted probability distribution. */
choice = sum * (((double) random()) / MAX_RANDOM_VALUE);
/*
 * Walk the weights until the random mass is used up. The loop stops at
 * numSamples - 1, so j is always a valid sample index afterwards even if
 * choice is still positive.
 */
for (j = 0; j < numSamples - 1; j++)
{
choice -= weight[j];
if (choice <= 0)
break;
}
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
centers->length++;
}
pfree(weight);
}
/*
 * Apply norm to vector
 *
 * Divides every component of vec by the value returned by the norm support
 * function. The vector is left untouched unless that value is positive.
 */
static inline void
ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec)
{
	double		norm;
	int			d;

	norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec)));

	/* TODO Handle zero norm */
	if (!(norm > 0))
		return;

	for (d = 0; d < vec->dim; d++)
		vec->x[d] = vec->x[d] / norm;
}
/*
* Compare vectors
*/
static int
CompareVectors(const void *a, const void *b)
{
return vector_cmp_internal((Vector *) a, (Vector *) b);
}
/*
 * Quick approach if we have little data
 *
 * Uses the distinct samples as centers and fills any remaining slots with
 * random vectors. The samples array is sorted in place.
 */
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
{
	int			dimensions = centers->dim;
	Oid			collation = index->rd_indcollation[0];
	FmgrInfo   *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
	Vector	   *prev = NULL;
	int			s;

	/* Copy existing vectors while avoiding duplicates */
	qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);

	for (s = 0; s < samples->length; s++)
	{
		Vector	   *cur = VectorArrayGet(samples, s);

		/* After sorting, duplicates are adjacent */
		if (prev == NULL || CompareVectors(cur, prev) != 0)
		{
			VectorArraySet(centers, centers->length, cur);
			centers->length++;
		}

		prev = cur;
	}

	/* Fill remaining with random data */
	for (; centers->length < centers->maxlen; centers->length++)
	{
		Vector	   *filler = VectorArrayGet(centers, centers->length);
		int			d;

		SET_VARSIZE(filler, VECTOR_SIZE(dimensions));
		filler->dim = dimensions;

		for (d = 0; d < dimensions; d++)
			filler->x[d] = ((double) random()) / MAX_RANDOM_VALUE;

		/* Normalize if needed (only needed for random centers) */
		if (normprocinfo != NULL)
			ApplyNorm(normprocinfo, collation, filler);
	}
}
/*
 * Use Elkan for performance. This requires distance function to satisfy triangle inequality.
 *
 * We use L2 distance for L2 (not L2 squared like index scan)
 * and angular distance for inner product and cosine distance
 *
 * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
 *
 * Fills centers (all centers->maxlen of them) from samples. The "Step N"
 * comments below refer to the steps of the algorithm in the paper.
 *
 * Bookkeeping arrays:
 *   lowerBound[j * numCenters + k]  lower bound on d(sample j, center k)
 *   upperBound[j]                   upper bound on d(sample j, its center)
 *   closestCenters[j]               index of the center sample j is assigned to
 *   halfcdist[a * numCenters + b]   half the distance between centers a and b
 *   s[c]                            smallest halfcdist from center c to another center
 *   newcdist[c]                     how far center c moved in this iteration
 */
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
Vector *vec;
Vector *newCenter;
int iteration;
int j;
int k;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
int *centerCounts;
int *closestCenters;
double *lowerBound;
double *upperBound;
double *s;
double *halfcdist;
double *newcdist;
int changes;
double minDistance;
int closestCenter;
double distance;
bool rj;
bool rjreset;
double dxcx;
double dxc;
/* Set support functions */
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
collation = index->rd_indcollation[0];
/* Allocate space */
centerCounts = palloc(sizeof(int) * numCenters);
closestCenters = palloc(sizeof(int) * numSamples);
lowerBound = palloc(sizeof(double) * numSamples * numCenters);
upperBound = palloc(sizeof(double) * numSamples);
s = palloc(sizeof(double) * numCenters);
halfcdist = palloc(sizeof(double) * numCenters * numCenters);
newcdist = palloc(sizeof(double) * numCenters);
newCenters = VectorArrayInit(numCenters, dimensions);
/* Give every scratch center a valid varlena header and dimension count */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
/* Pick initial centers */
InitCenters(index, samples, centers, lowerBound);
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
for (j = 0; j < numSamples; j++)
{
minDistance = DBL_MAX;
closestCenter = -1;
vec = VectorArrayGet(samples, j);
/* Find closest center */
for (k = 0; k < numCenters; k++)
{
/* TODO Use Lemma 1 in k-means++ initialization */
/* InitCenters stored the sample-to-center distances in lowerBound */
distance = lowerBound[j * numCenters + k];
if (distance < minDistance)
{
minDistance = distance;
closestCenter = k;
}
}
upperBound[j] = minDistance;
closestCenters[j] = closestCenter;
}
/* Give 500 iterations to converge */
for (iteration = 0; iteration < 500; iteration++)
{
/* Can take a while, so ensure we can interrupt */
CHECK_FOR_INTERRUPTS();
changes = 0;
/* Step 1: For all centers, compute distance */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(centers, j);
/* Only pairs with k > j are computed; the result is stored for both orders */
for (k = j + 1; k < numCenters; k++)
{
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
halfcdist[j * numCenters + k] = distance;
halfcdist[k * numCenters + j] = distance;
}
}
/* For all centers c, compute s(c) */
for (j = 0; j < numCenters; j++)
{
minDistance = DBL_MAX;
for (k = 0; k < numCenters; k++)
{
/* The diagonal of halfcdist is never written, so skip it */
if (j == k)
continue;
distance = halfcdist[j * numCenters + k];
if (distance < minDistance)
minDistance = distance;
}
s[j] = minDistance;
}
/*
 * r(x) starts out false in the first iteration and true in every later
 * one (upper bounds are adjusted in Step 6, so they are no longer exact).
 */
rjreset = iteration != 0;
for (j = 0; j < numSamples; j++)
{
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
if (upperBound[j] <= s[closestCenters[j]])
continue;
rj = rjreset;
for (k = 0; k < numCenters; k++)
{
/* Step 3: For all remaining points x and centers c */
if (k == closestCenters[j])
continue;
if (upperBound[j] <= lowerBound[j * numCenters + k])
continue;
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
continue;
vec = VectorArrayGet(samples, j);
/* Step 3a */
if (rj)
{
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
/* d(x,c(x)) computed, which is a form of d(x,c) */
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
upperBound[j] = dxcx;
rj = false;
}
else
dxcx = upperBound[j];
/* Step 3b */
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
{
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
/* d(x,c) calculated */
lowerBound[j * numCenters + k] = dxc;
if (dxc < dxcx)
{
closestCenters[j] = k;
/* c(x) changed */
upperBound[j] = dxc;
changes++;
}
}
}
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
/* Reset the per-center sums and counts */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
for (k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
centerCounts[j] = 0;
}
for (j = 0; j < numSamples; j++)
{
vec = VectorArrayGet(samples, j);
closestCenter = closestCenters[j];
/* Increment sum and count of closest center */
newCenter = VectorArrayGet(newCenters, closestCenter);
for (k = 0; k < dimensions; k++)
newCenter->x[k] += vec->x[k];
centerCounts[closestCenter] += 1;
}
/* Turn the sums into means */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
if (centerCounts[j] > 0)
{
for (k = 0; k < dimensions; k++)
vec->x[k] /= centerCounts[j];
}
else
{
/* TODO Handle empty centers properly */
/* A center with no samples is replaced by a random vector */
for (k = 0; k < dimensions; k++)
vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
}
/* Normalize if needed */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
}
/* Step 5 */
/* Lower each lower bound by how far its center moved, clamping at zero */
for (j = 0; j < numCenters; j++)
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
for (j = 0; j < numSamples; j++)
{
for (k = 0; k < numCenters; k++)
{
distance = lowerBound[j * numCenters + k] - newcdist[k];
if (distance < 0)
distance = 0;
lowerBound[j * numCenters + k] = distance;
}
}
/* Step 6 */
/* We reset r(x) before Step 3 in the next iteration */
/* Raise each upper bound by how far the assigned center moved */
for (j = 0; j < numSamples; j++)
upperBound[j] += newcdist[closestCenters[j]];
/* Step 7 */
/* Replace each center with its new value */
for (j = 0; j < numCenters; j++)
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
/* Stop once no assignment changed; the first iteration never stops the loop */
if (changes == 0 && iteration != 0)
break;
}
pfree(newCenters);
pfree(centerCounts);
pfree(closestCenters);
pfree(lowerBound);
pfree(upperBound);
pfree(s);
pfree(halfcdist);
pfree(newcdist);
}
/*
 * Detect issues with centers
 *
 * Raises an error if the array is not full, if two centers are equal, or -
 * when the index has a norm support function - if a center has zero norm.
 * The centers are sorted in place as a side effect.
 */
static void
CheckCenters(Relation index, VectorArray centers)
{
	FmgrInfo   *normprocinfo;
	Oid			collation;
	int			c;

	if (centers->length != centers->maxlen)
		elog(ERROR, "Not enough centers. Please report a bug.");

	/* Ensure no duplicate centers */
	/* Fine to sort in-place */
	qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);

	for (c = 1; c < centers->length; c++)
	{
		Vector	   *cur = VectorArrayGet(centers, c);
		Vector	   *prev = VectorArrayGet(centers, c - 1);

		if (CompareVectors(cur, prev) == 0)
			elog(ERROR, "Duplicate centers detected. Please report a bug.");
	}

	/* Ensure no zero vectors for cosine distance */
	/* Check NORM_PROC instead of KMEANS_NORM_PROC */
	normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
	if (normprocinfo == NULL)
		return;

	collation = index->rd_indcollation[0];

	for (c = 0; c < centers->length; c++)
	{
		double		norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, c))));

		if (norm == 0)
			elog(ERROR, "Zero norm detected. Please report a bug.");
	}
}
/*
 * Perform naive k-means centering
 * We use spherical k-means for inner product and cosine
 *
 * With no more samples than centers, the samples are used directly;
 * otherwise Elkan's k-means is run. The result is validated either way.
 */
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
{
	bool		enoughSamples = samples->length > centers->maxlen;

	if (enoughSamples)
		ElkanKmeans(index, samples, centers);
	else
		QuickCenters(index, samples, centers);

	CheckCenters(index, centers);
}

327
ivfscan.c Normal file
View File

@@ -0,0 +1,327 @@
#include "postgres.h"
#include "access/relscan.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "storage/bufmgr.h"
#if PG_VERSION_NUM >= 110000
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#else
#include "catalog/pg_operator.h"
#include "catalog/pg_type.h"
#endif
/*
* Compare list distances
*/
static int
CompareLists(const void *a, const void *b)
{
double diff = (((IvfflatScanList *) a)->distance - ((IvfflatScanList *) b)->distance);
if (diff > 0)
return 1;
if (diff < 0)
return -1;
return 0;
}
/*
 * Get lists and sort by distance
 *
 * Reads every list entry from the chain of list pages, records its start
 * page and the distance between its center and value in so->lists, sorts
 * the entries by distance and caps so->probes at the number of lists found.
 */
static void
GetScanLists(IndexScanDesc scan, Datum value)
{
	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
	BlockNumber blkno = IVFFLAT_HEAD_BLKNO;
	int			nlists = 0;

	/* Search all list pages */
	while (BlockNumberIsValid(blkno))
	{
		Buffer		cbuf;
		Page		cpage;
		OffsetNumber offno;
		OffsetNumber maxoffno;

		cbuf = ReadBuffer(scan->indexRelation, blkno);
		LockBuffer(cbuf, BUFFER_LOCK_SHARE);
		cpage = BufferGetPage(cbuf);
		maxoffno = PageGetMaxOffsetNumber(cpage);

		for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
		{
			IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
			IvfflatScanList *entry = &so->lists[nlists];

			/* Use procinfo from the index instead of scan key for performance */
			entry->distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
			entry->startPage = list->startPage;
			nlists++;
		}

		blkno = IvfflatPageGetOpaque(cpage)->nextblkno;
		UnlockReleaseBuffer(cbuf);
	}

	/* Sort by distance */
	qsort(so->lists, nlists, sizeof(IvfflatScanList), CompareLists);

	if (so->probes > nlists)
		so->probes = nlists;
}
/*
 * Get items
 *
 * Visits the entry pages of the so->probes closest lists and feeds one
 * tuple per index entry into so->sortstate. Each tuple has the columns
 * (distance to value, heap block number, heap offset, index block number).
 *
 * The virtual slot and the buffer access strategy created here are released
 * before returning.
 */
static void
GetScanItems(IndexScanDesc scan, Datum value)
{
	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
	Buffer		buf;
	Page		page;
	IndexTuple	itup;
	BlockNumber searchPage;
	OffsetNumber offno;
	OffsetNumber maxoffno;
	Datum		datum;
	bool		isnull;
	int			i;
	TupleDesc	tupdesc = RelationGetDescr(scan->indexRelation);

#if PG_VERSION_NUM >= 120000
	TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsVirtual);
#else
	TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc);
#endif

	/*
	 * Reuse same set of shared buffers for scan
	 *
	 * See postgres/src/backend/storage/buffer/README for description
	 */
	BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);

	/* Search closest probes lists */
	for (i = 0; i < so->probes; i++)
	{
		searchPage = so->lists[i].startPage;

		/* Search all entry pages for list */
		while (BlockNumberIsValid(searchPage))
		{
			buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
			LockBuffer(buf, BUFFER_LOCK_SHARE);
			page = BufferGetPage(buf);
			maxoffno = PageGetMaxOffsetNumber(page);

			for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
			{
				itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
				datum = index_getattr(itup, 1, tupdesc, &isnull);

				/*
				 * Add virtual tuple
				 *
				 * Use procinfo from the index instead of scan key for
				 * performance
				 */
				ExecClearTuple(slot);
				slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value);
				slot->tts_isnull[0] = false;
				slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid));
				slot->tts_isnull[1] = false;
				slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid));
				slot->tts_isnull[2] = false;
				slot->tts_values[3] = Int32GetDatum((int) searchPage);
				slot->tts_isnull[3] = false;
				ExecStoreVirtualTuple(slot);

				tuplesort_puttupleslot(so->sortstate, slot);
			}

			searchPage = IvfflatPageGetOpaque(page)->nextblkno;
			UnlockReleaseBuffer(buf);
		}
	}

	/*
	 * Release per-call resources. The slot's contents have already been
	 * handed to the tuplesort, and no buffer read with the strategy is still
	 * held at this point.
	 */
	ExecDropSingleTupleTableSlot(slot);
	FreeAccessStrategy(bas);
}
/*
 * Prepare for an index scan
 *
 * Allocates the scan descriptor and its opaque state (with one lists[] slot
 * per index list), looks up the support functions, and sets up the tuplesort
 * and the slot used to read sorted tuples back.
 */
IndexScanDesc
ivfflatbeginscan(Relation index, int nkeys, int norderbys)
{
	/* Sort on the first column (distance) with the float8 "<" operator */
	AttrNumber	attNums[] = {1};
	Oid			sortOperators[] = {Float8LessOperator};
	Oid			sortCollations[] = {InvalidOid};
	bool		nullsFirstFlags[] = {false};
	IndexScanDesc scan;
	IvfflatScanOpaque so;
	int			lists;
	Size		opaqueSize;

	scan = RelationGetIndexScan(index, nkeys, norderbys);
	lists = IvfflatGetLists(scan->indexRelation);

	opaqueSize = offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList);
	so = (IvfflatScanOpaque) palloc(opaqueSize);
	so->first = true;
	so->buf = InvalidBuffer;

	/* Set support functions */
	so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
	so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
	so->collation = index->rd_indcollation[0];

	/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
	so->tupdesc = CreateTemplateTupleDesc(4);
#else
	so->tupdesc = CreateTemplateTupleDesc(4, false);
#endif
	TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0);
	TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
	TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
	TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0);

	/* Prep sort */
#if PG_VERSION_NUM >= 110000
	so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false);
#else
	so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, false);
#endif

	/* Slot for reading tuples back out of the sort */
#if PG_VERSION_NUM >= 120000
	so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple);
#else
	so->slot = MakeSingleTupleTableSlot(so->tupdesc);
#endif

	scan->opaque = so;

	return scan;
}
/*
 * Start or restart an index scan
 *
 * If the previous scan already collected and sorted its items (so->first is
 * false), the sort state must be emptied before new items are added.
 * tuplesort_reset is used where available (PostgreSQL 13 and later); on
 * older versions the sort is ended and started again with the same settings
 * as in ivfflatbeginscan.
 *
 * NOTE(review): on older versions the replacement sort state is allocated in
 * the memory context that is current during rescan - presumably the same
 * per-query context used by ivfflatbeginscan; confirm.
 */
void
ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
{
	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;

	if (!so->first)
	{
#if PG_VERSION_NUM >= 130000
		tuplesort_reset(so->sortstate);
#else
		AttrNumber	attNums[] = {1};
		Oid			sortOperators[] = {Float8LessOperator};
		Oid			sortCollations[] = {InvalidOid};
		bool		nullsFirstFlags[] = {false};

		/* The slot may still reference a tuple owned by the old sort */
		ExecClearTuple(so->slot);
		tuplesort_end(so->sortstate);

#if PG_VERSION_NUM >= 110000
		so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false);
#else
		so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, false);
#endif
#endif
	}

	so->first = true;
	so->probes = ivfflat_probes;

	if (keys && scan->numberOfKeys > 0)
		memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));

	if (orderbys && scan->numberOfOrderBys > 0)
		memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
}
/*
 * Fetch the next tuple in the given scan
 *
 * On the first call after a (re)scan, the closest lists are located, their
 * items are collected and sorted by distance to the ORDER BY argument. Each
 * call then returns the next heap TID in that order, keeping a pin on the
 * index page the item came from.
 *
 * Returns false when there are no more items, when the ORDER BY argument is
 * NULL, or when it cannot be normalized. Raises an error if the scan has no
 * ORDER BY key at all, since there is then no value to measure against.
 */
bool
ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
{
	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;

	/*
	 * Index can be used to scan backward, but Postgres doesn't support
	 * backward scan on operators
	 */
	Assert(ScanDirectionIsForward(dir));

	if (so->first)
	{
		Datum		value;

		/* Without an ORDER BY key, orderByData is not usable */
		if (scan->orderByData == NULL)
			elog(ERROR, "cannot scan ivfflat index without order");

		/* No items will match if null */
		if (scan->orderByData->sk_flags & SK_ISNULL)
			return false;

		value = scan->orderByData->sk_argument;

		if (so->normprocinfo != NULL)
		{
			/* No items will match if normalization fails */
			if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL))
				return false;
		}

		GetScanLists(scan, value);
		GetScanItems(scan, value);
		tuplesort_performsort(so->sortstate);
		so->first = false;

		/* Clean up if we allocated a new value */
		if (value != scan->orderByData->sk_argument)
			pfree(DatumGetPointer(value));
	}

#if PG_VERSION_NUM >= 100000
	if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL))
#else
	if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL))
#endif
	{
		/* Columns: 1 = distance, 2 = heap block, 3 = heap offset, 4 = index block */
		BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull));
		OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));
		BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull));

#if PG_VERSION_NUM >= 120000
		ItemPointerSet(&scan->xs_heaptid, blkno, offset);
#else
		ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset);
#endif

		if (BufferIsValid(so->buf))
			ReleaseBuffer(so->buf);

		/*
		 * An index scan must maintain a pin on the index page holding the
		 * item last returned by amgettuple
		 *
		 * https://www.postgresql.org/docs/current/index-locking.html
		 */
		so->buf = ReadBuffer(scan->indexRelation, indexblkno);

		scan->xs_recheckorderby = false;
		return true;
	}

	return false;
}
/*
 * End a scan and release resources
 *
 * Drops the pin on the last returned index page (if any), ends the sort and
 * frees the opaque scan state.
 */
void
ivfflatendscan(IndexScanDesc scan)
{
	IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
	Buffer		pinned = so->buf;

	/* Release pin */
	if (BufferIsValid(pinned))
		ReleaseBuffer(pinned);

	tuplesort_end(so->sortstate);

	scan->opaque = NULL;
	pfree(so);
}

176
ivfutils.c Normal file
View File

@@ -0,0 +1,176 @@
#include "postgres.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
#include "vector.h"
/*
* Allocate a vector array
*/
VectorArray
VectorArrayInit(int maxlen, int dimensions)
{
VectorArray res = palloc0(VECTOR_ARRAY_SIZE(maxlen, dimensions));
res->length = 0;
res->maxlen = maxlen;
res->dim = dimensions;
return res;
}
/*
 * Print vector array - useful for debugging
 *
 * Calls PrintVector with msg for each stored vector, in order.
 */
void
PrintVectorArray(char *msg, VectorArray arr)
{
	int			pos = 0;

	while (pos < arr->length)
	{
		PrintVector(msg, VectorArrayGet(arr, pos));
		pos++;
	}
}
/*
 * Get the number of lists in the index
 *
 * Taken from the index reloptions, or IVFFLAT_DEFAULT_LISTS when the index
 * has none.
 */
int
IvfflatGetLists(Relation index)
{
	IvfflatOptions *opts = (IvfflatOptions *) index->rd_options;

	return opts ? opts->lists : IVFFLAT_DEFAULT_LISTS;
}
/*
 * Get proc
 *
 * Returns the support function info for procnum on the first index column,
 * or NULL if no such support function is defined.
 */
FmgrInfo *
IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
{
	Oid			procid = index_getprocid(rel, 1, procnum);

	return OidIsValid(procid) ? index_getprocinfo(rel, 1, procnum) : NULL;
}
/*
 * Divide by the norm
 *
 * Returns false if value should not be indexed
 *
 * On success, *value is replaced by a datum pointing at the normalized
 * vector. If result is NULL, a new vector is allocated for it; otherwise
 * result is filled in and used.
 *
 * The caller needs to free the pointer stored in value
 * if it's different than the original value
 */
bool
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
{
	double		norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
	Vector	   *src;
	int			d;

	/* Only a positive norm can be divided by */
	if (!(norm > 0))
		return false;

	src = (Vector *) DatumGetPointer(*value);

	if (result == NULL)
		result = InitVector(src->dim);

	for (d = 0; d < src->dim; d++)
		result->x[d] = src->x[d] / norm;

	*value = PointerGetDatum(result);

	return true;
}
/*
 * New buffer
 *
 * Extends the given fork of the index by one page and returns its buffer,
 * exclusively locked.
 */
Buffer
IvfflatNewBuffer(Relation index, ForkNumber forkNum)
{
	Buffer		fresh;

	fresh = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
	LockBuffer(fresh, BUFFER_LOCK_EXCLUSIVE);

	return fresh;
}
/*
 * Init page
 *
 * Starts a generic WAL record, registers *buf with a full page image and
 * initializes the page with an IVFFlat special area that ends the chain
 * (no next page). The WAL state and page are returned through state and
 * page.
 */
void
IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
	GenericXLogState *xlog;
	Page		fresh;
	IvfflatPageOpaque opaque;

	xlog = GenericXLogStart(index);
	fresh = GenericXLogRegisterBuffer(xlog, *buf, GENERIC_XLOG_FULL_IMAGE);
	PageInit(fresh, BufferGetPageSize(*buf), sizeof(IvfflatPageOpaqueData));

	opaque = IvfflatPageGetOpaque(fresh);
	opaque->nextblkno = InvalidBlockNumber;
	opaque->page_id = IVFFLAT_PAGE_ID;

	*state = xlog;
	*page = fresh;
}
/*
 * Commit buffer
 *
 * Applies and logs the changes registered in state, then unlocks and
 * releases buf. buf is expected to be the buffer registered with state.
 *
 * No explicit MarkBufferDirty call is made: GenericXLogFinish marks the
 * registered buffers dirty itself when it applies the changes, and marking
 * the buffer beforehand would flag it before the page is actually modified.
 */
void
IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
{
	GenericXLogFinish(state);
	UnlockReleaseBuffer(buf);
}
/*
 * Add a new page
 *
 * The order is very important!!
 *
 * On entry, *buf / *page / *state describe the current last page of a chain
 * (locked and registered with an open generic WAL record). On return they
 * describe the newly added page, which has been initialized and linked
 * after the previous one; the previous page has been committed and released.
 */
void
IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
{
Buffer prevbuf = *buf;
/* Get new buffer */
*buf = IvfflatNewBuffer(index, forkNum);
/* Update and commit previous buffer */
/* *page still refers to the previous page here; link it to the new block */
IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(*buf);
IvfflatCommitBuffer(prevbuf, *state);
/* Init new page */
/* Starts a new WAL record and overwrites *page and *state */
IvfflatInitPage(index, buf, page, state);
}
/*
 * Update the start or insert page of a list
 *
 * Rewrites the list entry at listInfo. Passing InvalidBlockNumber for
 * insertPage or startPage leaves that field unchanged.
 *
 * The incoming state argument is not used; a new generic WAL record is
 * started for the update.
 */
void
IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo,
				  BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum)
{
	Buffer		buf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL);
	Page		page;
	IvfflatList list;

	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);

	state = GenericXLogStart(index);
	page = GenericXLogRegisterBuffer(state, buf, 0);
	list = (IvfflatList) PageGetItem(page, PageGetItemId(page, listInfo.offno));

	if (BlockNumberIsValid(startPage))
		list->startPage = startPage;

	if (BlockNumberIsValid(insertPage))
		list->insertPage = insertPage;

	/* Could only commit if changed, but extra complexity isn't needed */
	IvfflatCommitBuffer(buf, state);
}

151
ivfvacuum.c Normal file
View File

@@ -0,0 +1,151 @@
#include "postgres.h"
#include "commands/vacuum.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
/*
 * Bulk delete tuples from the index
 *
 * Walks the chain of list pages starting at IVFFLAT_HEAD_BLKNO. For every
 * list on a list page it walks that list's chain of entry pages and removes
 * each index tuple for which callback(heap tid, callback_state) returns
 * true. After a list has been processed, its insert page is pointed at the
 * first entry page that had tuples removed, so the freed space can be
 * reused.
 *
 * stats is allocated (zeroed) when the caller passes NULL. tuples_removed
 * and num_index_tuples are accumulated into it, and it is returned.
 */
IndexBulkDeleteResult *
ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
				  IndexBulkDeleteCallback callback, void *callback_state)
{
	Relation	index = info->index;
	Buffer		cbuf;
	Page		cpage;
	Buffer		buf;
	Page		page;
	IvfflatList list;
	IndexTuple	itup;
	ItemPointer htup;
	OffsetNumber deletable[MaxOffsetNumber];
	int			ndeletable;

	/* Holds block numbers, so it must be BlockNumber, not OffsetNumber */
	BlockNumber startPages[MaxOffsetNumber];
	BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
	BlockNumber searchPage;
	BlockNumber insertPage;
	GenericXLogState *state;
	OffsetNumber coffno;
	OffsetNumber cmaxoffno;
	OffsetNumber offno;
	OffsetNumber maxoffno;
	ListInfo	listInfo;
	BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);

	if (stats == NULL)
		stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult));

	/* Iterate over list pages */
	while (BlockNumberIsValid(nextblkno))
	{
		cbuf = ReadBuffer(index, nextblkno);
		LockBuffer(cbuf, BUFFER_LOCK_SHARE);
		cpage = BufferGetPage(cbuf);
		cmaxoffno = PageGetMaxOffsetNumber(cpage);

		/*
		 * Copy the start page of every list so the list page does not have
		 * to stay locked while the entry pages are scanned
		 */
		for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
		{
			list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno));
			startPages[coffno - FirstOffsetNumber] = list->startPage;
		}

		listInfo.blkno = nextblkno;
		nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
		UnlockReleaseBuffer(cbuf);

		for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
		{
			searchPage = startPages[coffno - FirstOffsetNumber];
			insertPage = InvalidBlockNumber;

			/* Iterate over entry pages */
			while (BlockNumberIsValid(searchPage))
			{
				vacuum_delay_point();

				buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);

				/*
				 * ambulkdelete cannot delete entries from pages that are
				 * pinned by other backends
				 *
				 * https://www.postgresql.org/docs/current/index-locking.html
				 */
				LockBufferForCleanup(buf);

				state = GenericXLogStart(index);
				page = GenericXLogRegisterBuffer(state, buf, 0);
				maxoffno = PageGetMaxOffsetNumber(page);
				ndeletable = 0;

				/* Find deleted tuples */
				for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
				{
					itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
					htup = &(itup->t_tid);

					if (callback(htup, callback_state))
					{
						deletable[ndeletable++] = offno;
						stats->tuples_removed++;
					}
					else
						stats->num_index_tuples++;
				}

				/*
				 * Remember the first page that gains free space. This must
				 * be done before searchPage is advanced, otherwise the
				 * following page (or InvalidBlockNumber) would be recorded.
				 */
				if (ndeletable > 0 && !BlockNumberIsValid(insertPage))
					insertPage = searchPage;

				/* Read the link before the WAL state (and page copy) goes away */
				searchPage = IvfflatPageGetOpaque(page)->nextblkno;

				if (ndeletable > 0)
				{
					/* Delete tuples */
					PageIndexMultiDelete(page, deletable, ndeletable);
					MarkBufferDirty(buf);
					GenericXLogFinish(state);
				}
				else
					GenericXLogAbort(state);

				UnlockReleaseBuffer(buf);
			}

			/*
			 * Update after all tuples deleted.
			 *
			 * We don't add or delete items from lists pages, so offset won't
			 * change.
			 *
			 * Only needed when some page actually gained free space.
			 */
			if (BlockNumberIsValid(insertPage))
			{
				listInfo.offno = coffno;
				IvfflatUpdateList(index, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
			}
		}
	}

	return stats;
}
/*
 * Clean up after a VACUUM operation
 *
 * When stats is NULL nothing is done and NULL is returned. Otherwise the
 * current number of blocks in the index is stored in stats->num_pages and
 * stats is returned.
 */
IndexBulkDeleteResult *
ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats)
{
	if (stats != NULL)
		stats->num_pages = RelationGetNumberOfBlocks(info->index);

	return stats;
}

12
sql/btree.sql Normal file
View File

@@ -0,0 +1,12 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t (val);
SELECT * FROM t WHERE val = '[1,2,3]';
SELECT * FROM t ORDER BY val LIMIT 1;
DROP TABLE t;

11
sql/cast.sql Normal file
View File

@@ -0,0 +1,11 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT ARRAY[1,2,3]::vector;
SELECT ARRAY[1,2,3]::float4[]::vector;
SELECT ARRAY[1,2,3]::float8[]::vector;
SELECT '{NULL}'::real[]::vector;
SELECT '{NaN}'::real[]::vector;
SELECT '{Infinity}'::real[]::vector;
SELECT '{-Infinity}'::real[]::vector;
SELECT '{}'::real[]::vector;

18
sql/functions.sql Normal file
View File

@@ -0,0 +1,18 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT '[1,2,3]'::vector + '[4,5,6]';
SELECT '[1,2,3]'::vector - '[4,5,6]';
SELECT vector_dims('[1,2,3]');
SELECT round(vector_norm('[1,1]')::numeric, 5);
SELECT round(l2_distance('[1,2]', '[0,0]')::numeric, 5);
SELECT l2_distance('[1,2]', '[3]');
SELECT inner_product('[1,2]', '[3,4]');
SELECT inner_product('[1,2]', '[3]');
SELECT round(cosine_distance('[1,2]', '[2,4]')::numeric, 5);
SELECT cosine_distance('[1,2]', '[0,0]');
SELECT cosine_distance('[1,2]', '[3]');

14
sql/ivfflat_cosine.sql Normal file
View File

@@ -0,0 +1,14 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <=> '[3,3,3]';
SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector);
DROP TABLE t;

14
sql/ivfflat_ip.sql Normal file
View File

@@ -0,0 +1,14 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <#> '[3,3,3]';
SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector);
DROP TABLE t;

14
sql/ivfflat_l2.sql Normal file
View File

@@ -0,0 +1,14 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1);
INSERT INTO t (val) VALUES ('[1,2,4]');
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector);
DROP TABLE t;

11
sql/ivfflat_options.sql Normal file
View File

@@ -0,0 +1,11 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE TABLE t (val vector(3));
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 0);
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 32769);
SHOW ivfflat.probes;
DROP TABLE t;

11
sql/ivfflat_unlogged.sql Normal file
View File

@@ -0,0 +1,11 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SET enable_seqscan = off;
CREATE UNLOGGED TABLE t (val vector(3));
INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL);
CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1);
SELECT * FROM t ORDER BY val <-> '[3,3,3]';
DROP TABLE t;

15
sql/vector.sql Normal file
View File

@@ -0,0 +1,15 @@
SET client_min_messages = warning;
CREATE EXTENSION IF NOT EXISTS vector;
SELECT '[1,2,3]'::vector;
SELECT '[-1,2,3]'::vector;
SELECT '[hello,1]'::vector;
SELECT '[NaN,1]'::vector;
SELECT '[Infinity,1]'::vector;
SELECT '[-Infinity,1]'::vector;
SELECT '[1,2,3'::vector;
SELECT '[1,2,3]9'::vector;
SELECT '1,2,3'::vector;
SELECT '[]'::vector;
SELECT '[1,]'::vector;
SELECT '[1,2,3]'::vector(2);

80
t/001_wal.pl Normal file
View File

@@ -0,0 +1,80 @@
# Based on postgres/contrib/bloom/t/001_wal.pl
# Test generic xlog record work for ivfflat index replication.
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More tests => 31;
my $node_primary;
my $node_replica;
# Run the same index-ordered query on primary and replica and check that
# both return identical output. Takes the test name used in the report.
sub test_index_replay
{
	my $test_name = shift;

	# Block until the replica has written everything the primary produced
	my $applname = $node_replica->name;
	my $caughtup_query =
	  "SELECT pg_current_wal_lsn() <= write_lsn FROM pg_stat_replication WHERE application_name = '$applname';";
	$node_primary->poll_query_until('postgres', $caughtup_query)
	  or die "Timed out while waiting for replica 1 to catch up";

	# One random query vector, used on both nodes
	my ($r1, $r2, $r3) = (rand(), rand(), rand());
	my $queries = qq(SET enable_seqscan=off;
SELECT * FROM tst ORDER BY v <-> '[$r1,$r2,$r3]' LIMIT 10;
);

	# Execute on each node and compare the raw psql output
	my $primary_result = $node_primary->safe_psql("postgres", $queries);
	my $replica_result = $node_replica->safe_psql("postgres", $queries);
	is($primary_result, $replica_result, "$test_name: query result matches");

	return;
}
# Test driver: builds a primary and a streaming replica, creates an ivfflat
# index on the primary, then checks after every modification that both
# nodes answer the same query identically.
# Initialize primary node
$node_primary = get_new_node('primary');
$node_primary->init(allows_streaming => 1);
$node_primary->start;
my $backup_name = 'my_backup';
# Take backup
$node_primary->backup($backup_name);
# Create streaming replica linking to primary
$node_replica = get_new_node('replica');
$node_replica->init_from_backup($node_primary, $backup_name,
has_streaming => 1);
$node_replica->start;
# Create ivfflat index on primary
# (the extension, table, rows and index are all created after the backup
# was taken, so they reach the replica through streaming replication)
$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;");
$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i%10, ARRAY[random(), random(), random()] FROM generate_series(1,100000) i;"
);
$node_primary->safe_psql("postgres",
"CREATE INDEX ON tst USING ivfflat (v);");
# Test that queries give same result
test_index_replay('initial');
# Run 10 cycles of table modification. Run test queries after each modification.
# 1 initial check + 10 cycles x 3 checks = the 31 tests planned above.
for my $i (1 .. 10)
{
$node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;");
test_index_replay("delete $i");
$node_primary->safe_psql("postgres", "VACUUM tst;");
test_index_replay("vacuum $i");
# Each cycle inserts the next 10000 series values after the initial 100000
my ($start, $end) = (100001 + ($i - 1) * 10000, 100000 + $i * 10000);
$node_primary->safe_psql("postgres",
"INSERT INTO tst SELECT i%10, ARRAY[random(), random(), random()] FROM generate_series($start,$end) i;"
);
test_index_replay("insert $i");
}

210
vector--0.1.0.sql Normal file
View File

@@ -0,0 +1,210 @@
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "CREATE EXTENSION vector" to load this file. \quit
-- type
--
-- The first CREATE TYPE declares vector without a definition so that the
-- input, output and typmod functions can name it in their signatures; the
-- second CREATE TYPE then supplies the definition using those functions.
CREATE TYPE vector;
CREATE FUNCTION vector_in(cstring, oid, integer) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_out(vector) RETURNS cstring
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_typmod_in(cstring[]) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE TYPE vector (
INPUT = vector_in,
OUTPUT = vector_out,
TYPMOD_IN = vector_typmod_in
);
-- functions
--
-- User-callable functions implemented in the C module. All are declared
-- IMMUTABLE STRICT, so a NULL argument yields NULL without calling the C code.
-- Distances and scalar results
CREATE FUNCTION l2_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION inner_product(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION cosine_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_dims(vector) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_norm(vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- Element-wise arithmetic, exposed through the + and - operators
CREATE FUNCTION vector_add(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_sub(vector, vector) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- private functions
--
-- Not meant to be called directly: they back the comparison operators, the
-- <#> operator and the operator classes defined later in this script.
-- Comparison functions used by the <, <=, =, <>, >=, > operators
CREATE FUNCTION vector_lt(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_le(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_eq(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_ne(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_ge(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_gt(vector, vector) RETURNS bool
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- Three-way comparison registered as btree support function 1
CREATE FUNCTION vector_cmp(vector, vector) RETURNS int4
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- Distance variants registered as ivfflat operator class support functions
CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- cast functions
--
-- Each takes (value, typmod, explicit-flag), the three-argument form
-- accepted by CREATE CAST ... WITH FUNCTION. The three array_to_vector
-- overloads differ only in the array element type.
CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION array_to_vector(integer[], integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION array_to_vector(real[], integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
CREATE FUNCTION array_to_vector(double precision[], integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT;
-- casts
--
-- vector -> vector applies the declared dimension (typmod) to a value;
-- the array casts convert integer, real and double precision arrays.
-- All four are implicit.
CREATE CAST (vector AS vector)
WITH FUNCTION vector(vector, integer, boolean) AS IMPLICIT;
CREATE CAST (integer[] AS vector)
WITH FUNCTION array_to_vector(integer[], integer, boolean) AS IMPLICIT;
CREATE CAST (real[] AS vector)
WITH FUNCTION array_to_vector(real[], integer, boolean) AS IMPLICIT;
CREATE CAST (double precision[] AS vector)
WITH FUNCTION array_to_vector(double precision[], integer, boolean) AS IMPLICIT;
-- operators

-- Distance operators. Each is symmetric in its arguments, so each is
-- declared as its own commutator.
CREATE OPERATOR <-> (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = l2_distance,
	COMMUTATOR = '<->'
);

CREATE OPERATOR <#> (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_negative_inner_product,
	COMMUTATOR = '<#>'
);

CREATE OPERATOR <=> (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = cosine_distance,
	COMMUTATOR = '<=>'
);

-- Element-wise arithmetic. Addition commutes.
CREATE OPERATOR + (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_add,
	COMMUTATOR = +
);

-- Subtraction does not commute (a - b <> b - a), so no COMMUTATOR is
-- declared; declaring one would let the planner swap the operands.
CREATE OPERATOR - (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_sub
);

-- Comparison operators, used by the btree operator class
CREATE OPERATOR < (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt,
	COMMUTATOR = > , NEGATOR = >= ,
	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
);

-- should use scalarlesel and scalarlejoinsel, but not supported in Postgres < 11
CREATE OPERATOR <= (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_le,
	COMMUTATOR = >= , NEGATOR = > ,
	RESTRICT = scalarltsel, JOIN = scalarltjoinsel
);

CREATE OPERATOR = (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_eq,
	COMMUTATOR = = , NEGATOR = <> ,
	RESTRICT = eqsel, JOIN = eqjoinsel
);

-- Inequality needs the "not equal" selectivity estimators; the equality
-- estimators would make the planner treat <> as highly selective.
CREATE OPERATOR <> (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ne,
	COMMUTATOR = <> , NEGATOR = = ,
	RESTRICT = neqsel, JOIN = neqjoinsel
);

-- should use scalargesel and scalargejoinsel, but not supported in Postgres < 11
CREATE OPERATOR >= (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ge,
	COMMUTATOR = <= , NEGATOR = < ,
	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
);

CREATE OPERATOR > (
	LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_gt,
	COMMUTATOR = < , NEGATOR = <= ,
	RESTRICT = scalargtsel, JOIN = scalargtjoinsel
);
-- access method
--
-- Registers the ivfflat index access method; ivfflathandler is the C entry
-- point named as its handler.
CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;
CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler;
COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method';
-- opclasses
-- Default btree operator class: ordering and equality via vector_cmp
CREATE OPERATOR CLASS vector_ops
DEFAULT FOR TYPE vector USING btree AS
OPERATOR 1 < ,
OPERATOR 2 <= ,
OPERATOR 3 = ,
OPERATOR 4 >= ,
OPERATOR 5 > ,
FUNCTION 1 vector_cmp(vector, vector);
-- ivfflat operator classes, one per distance operator. Each supports
-- ORDER BY with its operator. The meaning of support function numbers 1-4
-- is defined by the ivfflat access method's C code, not by this script.
-- L2 distance (<->); the default for ivfflat
CREATE OPERATOR CLASS vector_l2_ops
DEFAULT FOR TYPE vector USING ivfflat AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector),
FUNCTION 3 l2_distance(vector, vector);
-- Negative inner product (<#>)
CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING ivfflat AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);
-- Cosine distance (<=>); the only class that also registers function 2
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING ivfflat AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);

610
vector.c Normal file
View File

@@ -0,0 +1,610 @@
#include "postgres.h"
#include <math.h>
#include "vector.h"
#include "fmgr.h"
#include "catalog/pg_type.h"
#include "lib/stringinfo.h"
#include "utils/array.h"
#include "utils/builtins.h"
#include "utils/lsyscache.h"
#if PG_VERSION_NUM >= 120000
#include "utils/float.h"
#endif
PG_MODULE_MAGIC;
/*
 * Raise an error unless both vectors have the same number of dimensions
 */
static inline void
CheckDims(Vector * a, Vector * b)
{
	if (a->dim == b->dim)
		return;

	ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("different vector dimensions %d and %d", a->dim, b->dim)));
}
/*
 * Raise an error when dim does not match the type modifier
 *
 * A typmod of -1 disables the check.
 */
static inline void
CheckExpectedDim(int32 typmod, int dim)
{
	if (typmod == -1 || typmod == dim)
		return;

	ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("expected %d dimensions, not %d", typmod, dim)));
}
/*
 * Reject element values that are not finite
 *
 * NaN is reported first, then positive or negative infinity.
 */
static inline void
CheckElement(float value)
{
	if (isnan(value))
	{
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("NaN not allowed in vector")));
	}
	else if (isinf(value))
	{
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("infinite value not allowed in vector")));
	}
}
/*
 * Debugging aid: log a vector at INFO level as "msg = [x1,x2,...]"
 */
void
PrintVector(char *msg, Vector * vector)
{
	StringInfoData out;
	int			idx;

	initStringInfo(&out);
	appendStringInfoChar(&out, '[');

	for (idx = 0; idx < vector->dim; idx++)
	{
		/* Separator goes before every element except the first */
		if (idx != 0)
			appendStringInfoString(&out, ",");

		appendStringInfoString(&out, float8out_internal(vector->x[idx]));
	}

	appendStringInfoChar(&out, ']');

	elog(INFO, "%s = %s", msg, out.data);
}
/*
 * Convert textual representation to internal representation
 *
 * Accepts "[v1,v2,...]". Argument 0 is the input string and argument 2 the
 * type modifier (expected dimension count, or -1 for none).
 *
 * The input is split on commas with strtok, which writes into the input
 * buffer and skips empty tokens. Each token is parsed with strtod and must
 * be a finite number; the last token must end with "]" and nothing may
 * follow it.
 *
 * Errors are raised for: missing "[", more than VECTOR_MAX_DIM elements,
 * unparsable elements, NaN or infinite elements, missing "]", trailing
 * characters, zero elements, and a dimension count that differs from the
 * type modifier.
 */
PG_FUNCTION_INFO_V1(vector_in);
Datum
vector_in(PG_FUNCTION_ARGS)
{
	char	   *str = PG_GETARG_CSTRING(0);
	int32		typmod = PG_GETARG_INT32(2);
	int			i;
	double		x[VECTOR_MAX_DIM];
	int			dim = 0;
	char	   *pt;
	char	   *stringEnd;
	Vector	   *result;

	if (*str != '[')
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
				 errmsg("malformed vector literal: \"%s\"", str),
				 errdetail("Vector contents must start with \"[\".")));

	str++;
	pt = strtok(str, ",");

	/* Lets the loop condition see "]" immediately for the input "[]" */
	stringEnd = pt;

	while (pt != NULL && *stringEnd != ']')
	{
		if (dim == VECTOR_MAX_DIM)
			ereport(ERROR,
					(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
					 errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));

		x[dim] = strtod(pt, &stringEnd);
		CheckElement(x[dim]);
		dim++;

		/* strtod consumed nothing: the token is not a number */
		if (stringEnd == pt)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					 errmsg("invalid input syntax for type vector: \"%s\"", pt)));

		/* Only the end of the token or the closing bracket may follow a number */
		if (*stringEnd != '\0' && *stringEnd != ']')
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					 errmsg("invalid input syntax for type vector: \"%s\"", pt)));

		pt = strtok(NULL, ",");
	}

	/*
	 * strtok returns NULL when nothing follows "[" (for example the input
	 * "["), leaving stringEnd NULL, so it must be checked before it is
	 * dereferenced.
	 */
	if (stringEnd == NULL || *stringEnd != ']')
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
				 errmsg("malformed vector literal"),
				 errdetail("Unexpected end of input.")));

	if (stringEnd[1] != '\0')
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
				 errmsg("malformed vector literal"),
				 errdetail("Junk after closing right brace.")));

	if (dim < 1)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("vector must have at least 1 dimension")));

	CheckExpectedDim(typmod, dim);

	result = InitVector(dim);
	for (i = 0; i < dim; i++)
		result->x[i] = x[i];

	PG_RETURN_POINTER(result);
}
/*
 * Produce the textual form "[x1,x2,...]" of a vector
 */
PG_FUNCTION_INFO_V1(vector_out);
Datum
vector_out(PG_FUNCTION_ARGS)
{
	Vector	   *vec = PG_GETARG_VECTOR_P(0);
	StringInfoData out;
	int			idx;

	initStringInfo(&out);
	appendStringInfoChar(&out, '[');

	for (idx = 0; idx < vec->dim; idx++)
	{
		/* Separator goes before every element except the first */
		if (idx != 0)
			appendStringInfoString(&out, ",");

		appendStringInfoString(&out, float8out_internal(vec->x[idx]));
	}

	appendStringInfoChar(&out, ']');

	PG_FREE_IF_COPY(vec, 0);
	PG_RETURN_CSTRING(out.data);
}
/*
 * Validate and return the type modifier
 *
 * Exactly one modifier is accepted, and it must lie between 1 and
 * VECTOR_MAX_DIM inclusive. It is returned unchanged.
 */
PG_FUNCTION_INFO_V1(vector_typmod_in);
Datum
vector_typmod_in(PG_FUNCTION_ARGS)
{
	ArrayType  *mods = PG_GETARG_ARRAYTYPE_P(0);
	int			nmods;
	int32	   *values = ArrayGetIntegerTypmods(mods, &nmods);
	int32		dims;

	if (nmods != 1)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("invalid type modifier")));

	dims = values[0];

	if (dims < 1)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("dimensions for type vector must be at least 1")));

	if (dims > VECTOR_MAX_DIM)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("dimensions for type vector cannot exceed %d", VECTOR_MAX_DIM)));

	PG_RETURN_INT32(dims);
}
/*
 * Cast a vector to vector with a type modifier
 *
 * The value itself is returned untouched; the cast only verifies that its
 * dimension count matches the type modifier.
 */
PG_FUNCTION_INFO_V1(vector);
Datum
vector(PG_FUNCTION_ARGS)
{
	Vector	   *input = PG_GETARG_VECTOR_P(0);
	int32		expected = PG_GETARG_INT32(1);

	CheckExpectedDim(expected, input->dim);

	PG_RETURN_POINTER(input);
}
/*
 * Build a vector from a one-dimensional array
 *
 * Argument 0 is the array and argument 1 the type modifier. Element types
 * int4, float8 and float4 are supported. With a type modifier the element
 * count must match it; without one the count must be between 1 and
 * VECTOR_MAX_DIM. NULL and non-finite elements are rejected.
 */
PG_FUNCTION_INFO_V1(array_to_vector);
Datum
array_to_vector(PG_FUNCTION_ARGS)
{
	ArrayType  *array = PG_GETARG_ARRAYTYPE_P(0);
	int32		typmod = PG_GETARG_INT32(1);
	Vector	   *result;
	int16		typlen;
	bool		typbyval;
	char		typalign;
	Datum	   *elemsp;
	bool	   *nullsp;
	int			nelemsp;
	int			idx;

	if (ARR_NDIM(array) > 1)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("array must be 1-D")));

	/* Break the array into per-element datums and null flags */
	get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
	deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, &nullsp, &nelemsp);

	if (typmod != -1)
		CheckExpectedDim(typmod, nelemsp);
	else if (nelemsp < 1)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("vector must have at least 1 dimension")));
	else if (nelemsp > VECTOR_MAX_DIM)
		ereport(ERROR,
				(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
				 errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));

	result = InitVector(nelemsp);

	for (idx = 0; idx < nelemsp; idx++)
	{
		if (nullsp[idx])
			ereport(ERROR,
					(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
					 errmsg("array must not containing NULLs")));

		/* Convert according to the array's element type */
		switch (ARR_ELEMTYPE(array))
		{
			case INT4OID:
				result->x[idx] = DatumGetInt32(elemsp[idx]);
				break;
			case FLOAT8OID:
				result->x[idx] = DatumGetFloat8(elemsp[idx]);
				break;
			case FLOAT4OID:
				result->x[idx] = DatumGetFloat4(elemsp[idx]);
				break;
			default:
				ereport(ERROR,
						(errcode(ERRCODE_DATA_EXCEPTION),
						 errmsg("unsupported array type")));
		}

		CheckElement(result->x[idx]);
	}

	PG_RETURN_POINTER(result);
}
/*
 * Euclidean (L2) distance between two vectors of equal dimension
 */
PG_FUNCTION_INFO_V1(l2_distance);
Datum
l2_distance(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		sumSquares = 0.0;
	int			idx;

	CheckDims(lhs, rhs);

	/* Accumulate squared element differences */
	for (idx = 0; idx < lhs->dim; idx++)
	{
		sumSquares += pow(lhs->x[idx] - rhs->x[idx], 2);
	}

	PG_RETURN_FLOAT8(sqrt(sumSquares));
}
/*
 * Squared L2 distance between two vectors of equal dimension
 *
 * Same sum as l2_distance but without the final square root.
 */
PG_FUNCTION_INFO_V1(vector_l2_squared_distance);
Datum
vector_l2_squared_distance(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		sumSquares = 0.0;
	int			idx;

	CheckDims(lhs, rhs);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		sumSquares += pow(lhs->x[idx] - rhs->x[idx], 2);
	}

	PG_RETURN_FLOAT8(sumSquares);
}
/*
 * Inner (dot) product of two vectors of equal dimension
 */
PG_FUNCTION_INFO_V1(inner_product);
Datum
inner_product(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		dot = 0.0;
	int			idx;

	CheckDims(lhs, rhs);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		dot += lhs->x[idx] * rhs->x[idx];
	}

	PG_RETURN_FLOAT8(dot);
}
/*
 * Inner product of two vectors of equal dimension, multiplied by -1
 */
PG_FUNCTION_INFO_V1(vector_negative_inner_product);
Datum
vector_negative_inner_product(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		dot = 0.0;
	int			idx;

	CheckDims(lhs, rhs);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		dot += lhs->x[idx] * rhs->x[idx];
	}

	PG_RETURN_FLOAT8(dot * -1);
}
/*
 * Cosine distance between two vectors of equal dimension
 *
 * Computed as 1 - dot(a, b) / (|a| * |b|). No special handling is applied
 * when either norm is zero.
 */
PG_FUNCTION_INFO_V1(cosine_distance);
Datum
cosine_distance(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		dot = 0.0;
	double		sumSqLhs = 0.0;
	double		sumSqRhs = 0.0;
	double		denom;
	int			idx;

	CheckDims(lhs, rhs);

	/* One pass gathers the dot product and both squared norms */
	for (idx = 0; idx < lhs->dim; idx++)
	{
		dot += lhs->x[idx] * rhs->x[idx];
		sumSqLhs += pow(lhs->x[idx], 2);
		sumSqRhs += pow(rhs->x[idx], 2);
	}

	denom = sqrt(sumSqLhs) * sqrt(sumSqRhs);

	PG_RETURN_FLOAT8(1 - (dot / denom));
}
/*
 * Distance used for spherical k-means
 *
 * Returns the angular distance acos(dot) / pi, chosen because it satisfies
 * the triangle inequality. The inputs are assumed to be unit vectors, so no
 * normalization is done here.
 */
PG_FUNCTION_INFO_V1(vector_spherical_distance);
Datum
vector_spherical_distance(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	double		dot = 0.0;
	int			idx;

	CheckDims(lhs, rhs);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		dot += lhs->x[idx] * rhs->x[idx];
	}

	/* Rounding can push the product just outside [-1, 1], where acos gives NaN */
	if (dot > 1)
	{
		dot = 1;
	}
	else if (dot < -1)
	{
		dot = -1;
	}

	PG_RETURN_FLOAT8(acos(dot) / M_PI);
}
/*
* Get the dimensions of a vector
*/
PG_FUNCTION_INFO_V1(vector_dims);
Datum
vector_dims(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
PG_RETURN_INT32(a->dim);
}
/*
 * Euclidean (L2) norm of a vector
 */
PG_FUNCTION_INFO_V1(vector_norm);
Datum
vector_norm(PG_FUNCTION_ARGS)
{
	Vector	   *vec = PG_GETARG_VECTOR_P(0);
	double		sumSquares = 0.0;
	int			idx;

	for (idx = 0; idx < vec->dim; idx++)
	{
		sumSquares += pow(vec->x[idx], 2);
	}

	PG_RETURN_FLOAT8(sqrt(sumSquares));
}
/*
 * Element-wise sum of two vectors of equal dimension
 *
 * Returns a newly allocated vector.
 */
PG_FUNCTION_INFO_V1(vector_add);
Datum
vector_add(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	Vector	   *sum;
	int			idx;

	CheckDims(lhs, rhs);

	sum = InitVector(lhs->dim);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		sum->x[idx] = lhs->x[idx] + rhs->x[idx];
	}

	PG_RETURN_POINTER(sum);
}
/*
 * Element-wise difference (first minus second) of two vectors of equal
 * dimension
 *
 * Returns a newly allocated vector.
 */
PG_FUNCTION_INFO_V1(vector_sub);
Datum
vector_sub(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	Vector	   *diff;
	int			idx;

	CheckDims(lhs, rhs);

	diff = InitVector(lhs->dim);

	for (idx = 0; idx < lhs->dim; idx++)
	{
		diff->x[idx] = lhs->x[idx] - rhs->x[idx];
	}

	PG_RETURN_POINTER(diff);
}
/*
 * Three-way comparison of two vectors of equal dimension
 *
 * Elements are compared in order; the first differing pair decides. Returns
 * -1 when a sorts before b, 1 when it sorts after, and 0 when no element
 * compares less or greater.
 */
int
vector_cmp_internal(Vector * a, Vector * b)
{
	int			idx = 0;

	CheckDims(a, b);

	while (idx < a->dim)
	{
		float		lhs = a->x[idx];
		float		rhs = b->x[idx];

		if (lhs < rhs)
			return -1;
		else if (lhs > rhs)
			return 1;

		idx++;
	}

	return 0;
}
/*
 * SQL-callable "<": true when the first vector sorts before the second
 */
PG_FUNCTION_INFO_V1(vector_lt);
Datum
vector_lt(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order < 0);
}
/*
 * SQL-callable "<=": true when the first vector does not sort after the
 * second
 */
PG_FUNCTION_INFO_V1(vector_le);
Datum
vector_le(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order <= 0);
}
/*
 * SQL-callable "=": true when the vectors compare as equal
 */
PG_FUNCTION_INFO_V1(vector_eq);
Datum
vector_eq(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order == 0);
}
/*
 * SQL-callable "<>": true when the vectors do not compare as equal
 */
PG_FUNCTION_INFO_V1(vector_ne);
Datum
vector_ne(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order != 0);
}
/*
 * SQL-callable ">=": true when the first vector does not sort before the
 * second
 */
PG_FUNCTION_INFO_V1(vector_ge);
Datum
vector_ge(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order >= 0);
}
/*
 * SQL-callable ">": true when the first vector sorts after the second
 */
PG_FUNCTION_INFO_V1(vector_gt);
Datum
vector_gt(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_BOOL(order > 0);
}
/*
 * SQL-callable three-way comparison; returns the result of
 * vector_cmp_internal (-1, 0 or 1)
 */
PG_FUNCTION_INFO_V1(vector_cmp);
Datum
vector_cmp(PG_FUNCTION_ARGS)
{
	Vector	   *lhs = PG_GETARG_VECTOR_P(0);
	Vector	   *rhs = PG_GETARG_VECTOR_P(1);
	int			order = vector_cmp_internal(lhs, rhs);

	PG_RETURN_INT32(order);
}

4
vector.control Normal file
View File

@@ -0,0 +1,4 @@
comment = 'vector data type and ivfflat access method'
default_version = '0.1.0'
module_pathname = '$libdir/vector'
relocatable = true

41
vector.h Normal file
View File

@@ -0,0 +1,41 @@
#ifndef VECTOR_H
#define VECTOR_H
#include "postgres.h"
/* Largest number of dimensions a vector may have */
#define VECTOR_MAX_DIM 1024
/* Total size in bytes of a Vector holding _dim elements, fixed fields included */
#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))
/* Detoast a Datum and treat the result as a Vector pointer */
#define DatumGetVector(x) ((Vector *) PG_DETOAST_DATUM(x))
/* Fetch function argument number x as a Vector pointer */
#define PG_GETARG_VECTOR_P(x) DatumGetVector(PG_GETARG_DATUM(x))
/* Return a Vector pointer from a SQL-callable function */
#define PG_RETURN_VECTOR_P(x) PG_RETURN_POINTER(x)
/*
 * Layout of a vector value: a varlena header, the dimension count, and
 * then the elements themselves as a flexible array.
 */
typedef struct Vector
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int16 dim; /* number of dimensions */
int16 unused; /* not referenced in this header; presumably reserved -- TODO confirm */
float x[FLEXIBLE_ARRAY_MEMBER]; /* element values; VECTOR_SIZE sizes this for dim entries */
} Vector;
void PrintVector(char *msg, Vector * vector);
int vector_cmp_internal(Vector * a, Vector * b);
/*
* Allocate and initialize a new vector
*/
static inline Vector *
InitVector(int dim)
{
Vector *result;
int size;
size = VECTOR_SIZE(dim);
result = (Vector *) palloc0(size);
SET_VARSIZE(result, size);
result->dim = dim;
return result;
}
#endif