From 72e9cf06c1f6b47a7454631ddb041bc481fd48ae Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 13 Sep 2023 13:41:06 -0700 Subject: [PATCH] Added basic support for float4 arrays --- Makefile | 2 +- Makefile.win | 2 +- sql/vector.sql | 16 +++++++ src/float4.c | 58 +++++++++++++++++++++++++ src/hnsw.c | 7 +++ src/hnsw.h | 5 +++ src/hnswbuild.c | 6 ++- src/hnswutils.c | 14 ++++++ test/t/019_hnsw_array.pl | 93 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 200 insertions(+), 3 deletions(-) create mode 100644 src/float4.c create mode 100644 test/t/019_hnsw_array.pl diff --git a/Makefile b/Makefile index 1412ad7..9d9961a 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ EXTVERSION = 0.5.0 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +OBJS = src/float4.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o HEADERS = src/vector.h TESTS = $(wildcard test/sql/*.sql) diff --git a/Makefile.win b/Makefile.win index 7bcae57..3e0e65e 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,7 +1,7 @@ EXTENSION = vector EXTVERSION = 0.5.0 -OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +OBJS = src\float4.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj HEADERS = src\vector.h REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged diff --git a/sql/vector.sql b/sql/vector.sql index 137931f..190c38e 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -34,6 +34,9 @@ CREATE TYPE vector ( CREATE FUNCTION l2_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION float4_l2_distance(float4[], float4[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION inner_product(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -84,6 +87,9 @@ CREATE FUNCTION vector_cmp(vector, vector) RETURNS int4 CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION float4_l2_squared_distance(float4[], float4[]) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -164,6 +170,11 @@ CREATE OPERATOR <-> ( COMMUTATOR = '<->' ); +CREATE OPERATOR <-> ( + LEFTARG = float4[], RIGHTARG = float4[], PROCEDURE = float4_l2_distance, + COMMUTATOR = '<->' +); + CREATE OPERATOR <#> ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_negative_inner_product, COMMUTATOR = '<#>' @@ -280,6 +291,11 @@ CREATE OPERATOR CLASS vector_l2_ops OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_l2_squared_distance(vector, vector); +CREATE OPERATOR CLASS float4_l2_ops + FOR TYPE float4[] USING hnsw AS + OPERATOR 1 <-> (float4[], float4[]) FOR ORDER BY float_ops, + FUNCTION 1 float4_l2_squared_distance(float4[], float4[]); + CREATE OPERATOR CLASS vector_ip_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, diff --git a/src/float4.c b/src/float4.c new file mode 100644 index 0000000..5d5e650 --- /dev/null +++ b/src/float4.c @@ -0,0 +1,58 @@ +#include "postgres.h" + +#include "utils/array.h" + +/* + * Get the L2 distance between vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_distance); +Datum +float4_l2_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + float *ax = (float *) ARR_DATA_PTR(a); + float *bx = (float *) ARR_DATA_PTR(b); + float distance = 0.0; + float diff; + + /* TODO Check rank, dimensions, and nulls */ + int dim = ARR_DIMS(a)[0]; + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + diff = ax[i] - bx[i]; + distance += diff * diff; + } + + PG_RETURN_FLOAT8(sqrt((double) distance)); +} + +/* + * Get the L2 squared distance between vectors + * This saves a sqrt calculation + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_squared_distance); +Datum +float4_l2_squared_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + float *ax = (float *) ARR_DATA_PTR(a); + float *bx = (float *) ARR_DATA_PTR(b); + float distance = 0.0; + float diff; + + /* TODO Check rank, dimensions, and nulls */ + int dim = ARR_DIMS(a)[0]; + + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + diff = ax[i] - bx[i]; + distance += diff * diff; + } + + PG_RETURN_FLOAT8((double) distance); +} diff --git a/src/hnsw.c b/src/hnsw.c index 758e418..042d045 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -33,6 +33,12 @@ HnswInit(void) HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock +#endif + ); + add_int_reloption(hnsw_relopt_kind, "dimensions", "Number of dimensions", + HNSW_DEFAULT_DIMENSIONS, HNSW_MIN_DIMENSIONS, HNSW_MAX_DIMENSIONS +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock #endif ); @@ -125,6 +131,7 @@ hnswoptions(Datum reloptions, bool validate) static const relopt_parse_elt tab[] = { {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + {"dimensions", RELOPT_TYPE_INT, offsetof(HnswOptions, dimensions)}, }; #if PG_VERSION_NUM >= 130000 diff --git a/src/hnsw.h b/src/hnsw.h index 4a5664f..e0aa8ea 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -42,6 +42,9 @@ #define HNSW_DEFAULT_EF_SEARCH 40 #define HNSW_MIN_EF_SEARCH 1 #define HNSW_MAX_EF_SEARCH 1000 +#define HNSW_DEFAULT_DIMENSIONS -1 +#define HNSW_MIN_DIMENSIONS 1 +#define HNSW_MAX_DIMENSIONS HNSW_MAX_DIM /* Tuple types */ #define HNSW_ELEMENT_TUPLE_TYPE 1 @@ -131,6 +134,7 @@ typedef struct HnswOptions int32 vl_len_; /* varlena header (do not touch directly!) */ int m; /* number of connections */ int efConstruction; /* size of dynamic candidate list */ + int dimensions; } HnswOptions; typedef struct HnswBuildState @@ -259,6 +263,7 @@ typedef struct HnswVacuumState /* Methods */ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); +int HnswGetDimensions(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); void HnswCommitBuffer(Buffer buf, GenericXLogState *state); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 49a35e6..3f90c4e 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -400,6 +400,7 @@ HnswGetMaxInMemoryElements(int m, double ml, int dimensions) elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1); elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); elementSize += sizeof(ItemPointerData); + /* TODO Handle non-vector types */ elementSize += VECTOR_SIZE(dimensions); return (maintenance_work_mem * 1024L) / elementSize; } @@ -417,7 +418,10 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); - buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + buildstate->dimensions = HnswGetDimensions(index); + + if (buildstate->dimensions < 0) + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) diff --git a/src/hnswutils.c b/src/hnswutils.c index de816e2..a872a93 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -35,6 +35,20 @@ HnswGetEfConstruction(Relation index) return HNSW_DEFAULT_EF_CONSTRUCTION; } +/* + * Get the number of dimensions in the index + */ +int +HnswGetDimensions(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->dimensions; + + return HNSW_DEFAULT_DIMENSIONS; +} + /* * Get proc */ diff --git a/test/t/019_hnsw_array.pl b/test/t/019_hnsw_array.pl new file mode 100644 index 0000000..ca25370 --- /dev/null +++ b/test/t/019_hnsw_array.pl @@ -0,0 +1,93 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v float4[3]);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "{$r1,$r2,$r3}"); +} + +# Check each index type +my @operators = ("<->"); +my @opclasses = ("float4_l2_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass) WITH (dimensions = 3);"); + + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); +} + +done_testing();