mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-30 17:51:18 +08:00
Added basic support for float4 arrays
This commit is contained in:
2
Makefile
2
Makefile
@@ -3,7 +3,7 @@ EXTVERSION = 0.5.0
|
||||
|
||||
MODULE_big = vector
|
||||
DATA = $(wildcard sql/*--*.sql)
|
||||
OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
|
||||
OBJS = src/float4.o src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o
|
||||
HEADERS = src/vector.h
|
||||
|
||||
TESTS = $(wildcard test/sql/*.sql)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
EXTENSION = vector
|
||||
EXTVERSION = 0.5.0
|
||||
|
||||
OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
|
||||
OBJS = src\float4.obj src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj
|
||||
HEADERS = src\vector.h
|
||||
|
||||
REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged
|
||||
|
||||
@@ -34,6 +34,9 @@ CREATE TYPE vector (
|
||||
CREATE FUNCTION l2_distance(vector, vector) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION float4_l2_distance(float4[], float4[]) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION inner_product(vector, vector) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
@@ -84,6 +87,9 @@ CREATE FUNCTION vector_cmp(vector, vector) RETURNS int4
|
||||
CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION float4_l2_squared_distance(float4[], float4[]) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
@@ -164,6 +170,11 @@ CREATE OPERATOR <-> (
|
||||
COMMUTATOR = '<->'
|
||||
);
|
||||
|
||||
CREATE OPERATOR <-> (
|
||||
LEFTARG = float4[], RIGHTARG = float4[], PROCEDURE = float4_l2_distance,
|
||||
COMMUTATOR = '<->'
|
||||
);
|
||||
|
||||
CREATE OPERATOR <#> (
|
||||
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_negative_inner_product,
|
||||
COMMUTATOR = '<#>'
|
||||
@@ -280,6 +291,11 @@ CREATE OPERATOR CLASS vector_l2_ops
|
||||
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 vector_l2_squared_distance(vector, vector);
|
||||
|
||||
CREATE OPERATOR CLASS float4_l2_ops
|
||||
FOR TYPE float4[] USING hnsw AS
|
||||
OPERATOR 1 <-> (float4[], float4[]) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 float4_l2_squared_distance(float4[], float4[]);
|
||||
|
||||
CREATE OPERATOR CLASS vector_ip_ops
|
||||
FOR TYPE vector USING hnsw AS
|
||||
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
|
||||
|
||||
58
src/float4.c
Normal file
58
src/float4.c
Normal file
@@ -0,0 +1,58 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "utils/array.h"
|
||||
|
||||
/*
|
||||
* Get the L2 distance between vectors
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_distance);
|
||||
Datum
|
||||
float4_l2_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
|
||||
ArrayType *b = PG_GETARG_ARRAYTYPE_P(1);
|
||||
float *ax = (float *) ARR_DATA_PTR(a);
|
||||
float *bx = (float *) ARR_DATA_PTR(b);
|
||||
float distance = 0.0;
|
||||
float diff;
|
||||
|
||||
/* TODO Check rank, dimensions, and nulls */
|
||||
int dim = ARR_DIMS(a)[0];
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
diff = ax[i] - bx[i];
|
||||
distance += diff * diff;
|
||||
}
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt((double) distance));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 squared distance between vectors
|
||||
* This saves a sqrt calculation
|
||||
*/
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_squared_distance);
|
||||
Datum
|
||||
float4_l2_squared_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
|
||||
ArrayType *b = PG_GETARG_ARRAYTYPE_P(1);
|
||||
float *ax = (float *) ARR_DATA_PTR(a);
|
||||
float *bx = (float *) ARR_DATA_PTR(b);
|
||||
float distance = 0.0;
|
||||
float diff;
|
||||
|
||||
/* TODO Check rank, dimensions, and nulls */
|
||||
int dim = ARR_DIMS(a)[0];
|
||||
|
||||
/* Auto-vectorized */
|
||||
for (int i = 0; i < dim; i++)
|
||||
{
|
||||
diff = ax[i] - bx[i];
|
||||
distance += diff * diff;
|
||||
}
|
||||
|
||||
PG_RETURN_FLOAT8((double) distance);
|
||||
}
|
||||
@@ -33,6 +33,12 @@ HnswInit(void)
|
||||
HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
,AccessExclusiveLock
|
||||
#endif
|
||||
);
|
||||
add_int_reloption(hnsw_relopt_kind, "dimensions", "Number of dimensions",
|
||||
HNSW_DEFAULT_DIMENSIONS, HNSW_MIN_DIMENSIONS, HNSW_MAX_DIMENSIONS
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
,AccessExclusiveLock
|
||||
#endif
|
||||
);
|
||||
|
||||
@@ -125,6 +131,7 @@ hnswoptions(Datum reloptions, bool validate)
|
||||
static const relopt_parse_elt tab[] = {
|
||||
{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
|
||||
{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
|
||||
{"dimensions", RELOPT_TYPE_INT, offsetof(HnswOptions, dimensions)},
|
||||
};
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
|
||||
@@ -42,6 +42,9 @@
|
||||
#define HNSW_DEFAULT_EF_SEARCH 40
|
||||
#define HNSW_MIN_EF_SEARCH 1
|
||||
#define HNSW_MAX_EF_SEARCH 1000
|
||||
#define HNSW_DEFAULT_DIMENSIONS -1
|
||||
#define HNSW_MIN_DIMENSIONS 1
|
||||
#define HNSW_MAX_DIMENSIONS HNSW_MAX_DIM
|
||||
|
||||
/* Tuple types */
|
||||
#define HNSW_ELEMENT_TUPLE_TYPE 1
|
||||
@@ -131,6 +134,7 @@ typedef struct HnswOptions
|
||||
int32 vl_len_; /* varlena header (do not touch directly!) */
|
||||
int m; /* number of connections */
|
||||
int efConstruction; /* size of dynamic candidate list */
|
||||
int dimensions;
|
||||
} HnswOptions;
|
||||
|
||||
typedef struct HnswBuildState
|
||||
@@ -259,6 +263,7 @@ typedef struct HnswVacuumState
|
||||
/* Methods */
|
||||
int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
int HnswGetDimensions(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);
|
||||
|
||||
@@ -400,6 +400,7 @@ HnswGetMaxInMemoryElements(int m, double ml, int dimensions)
|
||||
elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1);
|
||||
elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2));
|
||||
elementSize += sizeof(ItemPointerData);
|
||||
/* TODO Handle non-vector types */
|
||||
elementSize += VECTOR_SIZE(dimensions);
|
||||
return (maintenance_work_mem * 1024L) / elementSize;
|
||||
}
|
||||
@@ -417,7 +418,10 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
|
||||
buildstate->m = HnswGetM(index);
|
||||
buildstate->efConstruction = HnswGetEfConstruction(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
buildstate->dimensions = HnswGetDimensions(index);
|
||||
|
||||
if (buildstate->dimensions < 0)
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
/* Require column to have dimensions to be indexed */
|
||||
if (buildstate->dimensions < 0)
|
||||
|
||||
@@ -35,6 +35,20 @@ HnswGetEfConstruction(Relation index)
|
||||
return HNSW_DEFAULT_EF_CONSTRUCTION;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the number of dimensions in the index
|
||||
*/
|
||||
int
|
||||
HnswGetDimensions(Relation index)
|
||||
{
|
||||
HnswOptions *opts = (HnswOptions *) index->rd_options;
|
||||
|
||||
if (opts)
|
||||
return opts->dimensions;
|
||||
|
||||
return HNSW_DEFAULT_DIMENSIONS;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get proc
|
||||
*/
|
||||
|
||||
93
test/t/019_hnsw_array.pl
Normal file
93
test/t/019_hnsw_array.pl
Normal file
@@ -0,0 +1,93 @@
|
||||
use strict;
|
||||
use warnings;
|
||||
use PostgresNode;
|
||||
use TestLib;
|
||||
use Test::More;
|
||||
|
||||
my $node;
|
||||
my @queries = ();
|
||||
my @expected;
|
||||
my $limit = 20;
|
||||
|
||||
sub test_recall
|
||||
{
|
||||
my ($min, $operator) = @_;
|
||||
my $correct = 0;
|
||||
my $total = 0;
|
||||
|
||||
my $explain = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit;
|
||||
));
|
||||
like($explain, qr/Index Scan/);
|
||||
|
||||
for my $i (0 .. $#queries)
|
||||
{
|
||||
my $actual = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
|
||||
));
|
||||
my @actual_ids = split("\n", $actual);
|
||||
my %actual_set = map { $_ => 1 } @actual_ids;
|
||||
|
||||
my @expected_ids = split("\n", $expected[$i]);
|
||||
|
||||
foreach (@expected_ids)
|
||||
{
|
||||
if (exists($actual_set{$_}))
|
||||
{
|
||||
$correct++;
|
||||
}
|
||||
$total++;
|
||||
}
|
||||
}
|
||||
|
||||
cmp_ok($correct / $total, ">=", $min, $operator);
|
||||
}
|
||||
|
||||
# Initialize node
|
||||
$node = get_new_node('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v float4[3]);");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
|
||||
);
|
||||
|
||||
# Generate queries
|
||||
for (1 .. 20)
|
||||
{
|
||||
my $r1 = rand();
|
||||
my $r2 = rand();
|
||||
my $r3 = rand();
|
||||
push(@queries, "{$r1,$r2,$r3}");
|
||||
}
|
||||
|
||||
# Check each index type
|
||||
my @operators = ("<->");
|
||||
my @opclasses = ("float4_l2_ops");
|
||||
|
||||
for my $i (0 .. $#operators)
|
||||
{
|
||||
my $operator = $operators[$i];
|
||||
my $opclass = $opclasses[$i];
|
||||
|
||||
# Get exact results
|
||||
@expected = ();
|
||||
foreach (@queries)
|
||||
{
|
||||
my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;");
|
||||
push(@expected, $res);
|
||||
}
|
||||
|
||||
# Add index
|
||||
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass) WITH (dimensions = 3);");
|
||||
|
||||
my $min = $operator eq "<#>" ? 0.80 : 0.99;
|
||||
test_recall($min, $operator);
|
||||
}
|
||||
|
||||
done_testing();
|
||||
Reference in New Issue
Block a user