diff --git a/README.md b/README.md index 21868c8..db37a1f 100644 --- a/README.md +++ b/README.md @@ -346,6 +346,7 @@ CREATE INDEX ON items USING ivfflat (embedding vector_cosine_ops) WITH (lists = Supported types are: - `vector` - up to 2,000 dimensions +- `halfvec` - up to 4,000 dimensions (unreleased) ### Query Options diff --git a/sql/vector--0.6.2--0.7.0.sql b/sql/vector--0.6.2--0.7.0.sql index 26ac3ed..cc641ed 100644 --- a/sql/vector--0.6.2--0.7.0.sql +++ b/sql/vector--0.6.2--0.7.0.sql @@ -140,6 +140,27 @@ CREATE OPERATOR <=> ( COMMUTATOR = '<=>' ); +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), + FUNCTION 3 l2_distance(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), + FUNCTION 4 halfvec_norm(halfvec); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 2 halfvec_norm(halfvec), + FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), + FUNCTION 4 halfvec_norm(halfvec); + CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, diff --git a/sql/vector.sql b/sql/vector.sql index 8b396da..a0d59dd 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -443,6 +443,27 @@ CREATE OPERATOR <=> ( -- halfvec opclasses +CREATE OPERATOR CLASS halfvec_l2_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec), + FUNCTION 3 l2_distance(halfvec, halfvec); + +CREATE OPERATOR CLASS halfvec_ip_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), + FUNCTION 4 halfvec_norm(halfvec); + +CREATE OPERATOR CLASS halfvec_cosine_ops + FOR TYPE halfvec USING ivfflat AS + OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, + FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec), + FUNCTION 2 halfvec_norm(halfvec), + FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec), + FUNCTION 4 halfvec_norm(halfvec); + CREATE OPERATOR CLASS halfvec_l2_ops FOR TYPE halfvec USING hnsw AS OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops, diff --git a/src/halfvec.c b/src/halfvec.c index b87a562..f0bd583 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -27,32 +27,6 @@ #include #endif -/* - * Check if half is NaN - */ -static inline bool -HalfIsNan(half num) -{ -#ifdef FLT16_SUPPORT - return isnan(num); -#else - return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00; -#endif -} - -/* - * Check if half is infinite - */ -static inline bool -HalfIsInf(half num) -{ -#ifdef FLT16_SUPPORT - return isinf(num); -#else - return (num & 0x7FFF) == 0x7C00; -#endif -} - /* * Get a half from a message buffer */ diff --git a/src/halfvec.h b/src/halfvec.h index 8f8abdb..db3b6fa 100644 --- a/src/halfvec.h +++ b/src/halfvec.h @@ -47,4 +47,24 @@ half Float4ToHalf(float num); half Float4ToHalfUnchecked(float num); int halfvec_cmp_internal(HalfVector * a, HalfVector * b); +static inline bool +HalfIsNan(half num) +{ +#ifdef FLT16_SUPPORT + return isnan(num); +#else + return (num & 0x7C00) == 0x7C00 && (num & 0x7FFF) != 0x7C00; +#endif +} + +static inline bool +HalfIsInf(half num) +{ +#ifdef FLT16_SUPPORT + return isinf(num); +#else + return (num & 0x7FFF) == 0x7C00; +#endif +} + #endif diff --git a/src/ivfbuild.c b/src/ivfbuild.c index cf01334..346f03e 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -10,12 +10,14 @@ #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" #include "commands/progress.h" +#include "halfvec.h" #include "ivfflat.h" #include "miscadmin.h" #include "optimizer/optimizer.h" #include "storage/bufmgr.h" #include "tcop/tcopprot.h" #include "utils/memutils.h" +#include "vector.h" #if PG_VERSION_NUM >= 140000 #include "utils/backend_progress.h" @@ -367,7 +369,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); - buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, VECTOR_SIZE(buildstate->dimensions)); + buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions, buildstate->type == IVFFLAT_TYPE_HALFVEC ? HALFVEC_SIZE(buildstate->dimensions) : VECTOR_SIZE(buildstate->dimensions)); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, diff --git a/src/ivfflat.h b/src/ivfflat.h index 034a039..3ca9178 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -45,7 +45,8 @@ typedef enum IvfflatType { - IVFFLAT_TYPE_VECTOR + IVFFLAT_TYPE_VECTOR, + IVFFLAT_TYPE_HALFVEC } IvfflatType; /* Build phases */ diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index 27ef778..e4b73e3 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -3,10 +3,12 @@ #include #include +#include "halfvec.h" #include "ivfflat.h" #include "miscadmin.h" #include "utils/datum.h" #include "utils/memutils.h" +#include "vector.h" /* * Initialize with kmeans++ @@ -101,6 +103,13 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Datum value, IvfflatType type) for (int i = 0; i < vec->dim; i++) vec->x[i] /= norm; } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + HalfVector *vec = DatumGetHalfVector(value); + + for (int i = 0; i < vec->dim; i++) + vec->x[i] = Float4ToHalfUnchecked(HalfToFloat4(vec->x[i]) / norm); + } else elog(ERROR, "Unsupported type"); } @@ -115,6 +124,15 @@ CompareVectors(const void *a, const void *b) return vector_cmp_internal((Vector *) a, (Vector *) b); } +/* + * Compare half vectors + */ +static int +CompareHalfVectors(const void *a, const void *b) +{ + return halfvec_cmp_internal((HalfVector *) a, (HalfVector *) b); +} + /* * Quick approach if we have little data */ @@ -130,6 +148,8 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy { if (type == IVFFLAT_TYPE_VECTOR) qsort(samples->items, samples->length, samples->itemsize, CompareVectors); + else if (type == IVFFLAT_TYPE_HALFVEC) + qsort(samples->items, samples->length, samples->itemsize, CompareHalfVectors); else elog(ERROR, "Unsupported type"); @@ -160,6 +180,16 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers, IvfflatTy for (int j = 0; j < dimensions; j++) vec->x[j] = RandomDouble(); } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + HalfVector *vec = DatumGetHalfVector(center); + + SET_VARSIZE(vec, HALFVEC_SIZE(dimensions)); + vec->dim = dimensions; + + for (int j = 0; j < dimensions; j++) + vec->x[j] = Float4ToHalfUnchecked((float) RandomDouble()); + } else elog(ERROR, "Unsupported type"); @@ -221,6 +251,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize); Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize); + Size aggCentersSize = type == IVFFLAT_TYPE_VECTOR ? 0 : VECTOR_ARRAY_SIZE(numCenters, VECTOR_SIZE(dimensions)); Size centerCountsSize = sizeof(int) * numCenters; Size closestCentersSize = sizeof(int) * numSamples; Size lowerBoundSize = sizeof(float) * numSamples * numCenters; @@ -230,7 +261,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp Size newcdistSize = sizeof(float) * numCenters; /* Calculate total size */ - Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; + Size totalSize = samplesSize + centersSize + newCentersSize + aggCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; /* Check memory requirements */ /* Add one to error message to ceil */ @@ -265,7 +296,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE); newcdist = palloc(newcdistSize); - aggCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); + aggCenters = VectorArrayInit(numCenters, dimensions, VECTOR_SIZE(dimensions)); for (int64 j = 0; j < numCenters; j++) { Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); @@ -276,6 +307,18 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp if (type == IVFFLAT_TYPE_VECTOR) newCenters = aggCenters; + else if (type == IVFFLAT_TYPE_HALFVEC) + { + newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize); + + for (int j = 0; j < numCenters; j++) + { + HalfVector *vec = (HalfVector *) VectorArrayGet(newCenters, j); + + SET_VARSIZE(vec, HALFVEC_SIZE(dimensions)); + vec->dim = dimensions; + } + } else elog(ERROR, "Unsupported type"); @@ -430,20 +473,32 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp for (int64 j = 0; j < numSamples; j++) { int closestCenter = closestCenters[j]; - Vector *vec = (Vector *) VectorArrayGet(samples, j); - Vector *newCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter); + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, closestCenter); /* Increment sum and count of closest center */ - for (int64 k = 0; k < dimensions; k++) - newCenter->x[k] += vec->x[k]; + if (type == IVFFLAT_TYPE_VECTOR) + { + Vector *vec = (Vector *) VectorArrayGet(samples, j); + + for (int64 k = 0; k < dimensions; k++) + aggCenter->x[k] += vec->x[k]; + } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + HalfVector *vec = (HalfVector *) VectorArrayGet(samples, j); + + for (int64 k = 0; k < dimensions; k++) + aggCenter->x[k] += HalfToFloat4(vec->x[k]); + } + else + elog(ERROR, "Unsupported type"); centerCounts[closestCenter] += 1; } for (int64 j = 0; j < numCenters; j++) { - Datum center = PointerGetDatum(VectorArrayGet(aggCenters, j)); - Vector *vec = DatumGetVector(center); + Vector *vec = (Vector *) VectorArrayGet(aggCenters, j); if (centerCounts[j] > 0) { @@ -464,10 +519,28 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatTyp for (int64 k = 0; k < dimensions; k++) vec->x[k] = RandomDouble(); } + } - /* Normalize if needed */ - if (normprocinfo != NULL) - ApplyNorm(normprocinfo, collation, center, type); + if (type == IVFFLAT_TYPE_HALFVEC) + { + for (int j = 0; j < numCenters; j++) + { + Vector *aggCenter = (Vector *) VectorArrayGet(aggCenters, j); + HalfVector *newCenter = (HalfVector *) VectorArrayGet(newCenters, j); + + for (int k = 0; k < dimensions; k++) + newCenter->x[k] = Float4ToHalfUnchecked(aggCenter->x[k]); + } + } + + if (normprocinfo != NULL) + { + for (int j = 0; j < numCenters; j++) + { + Datum newCenter = PointerGetDatum(VectorArrayGet(newCenters, j)); + + ApplyNorm(normprocinfo, collation, newCenter, type); + } } /* Step 5 */ @@ -531,6 +604,19 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) elog(ERROR, "Infinite value detected. Please report a bug."); } } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + HalfVector *vec = (HalfVector *) VectorArrayGet(centers, i); + + for (int j = 0; j < vec->dim; j++) + { + if (HalfIsNan(vec->x[j])) + elog(ERROR, "NaN detected. Please report a bug."); + + if (HalfIsInf(vec->x[j])) + elog(ERROR, "Infinite value detected. Please report a bug."); + } + } else elog(ERROR, "Unsupported type"); } @@ -539,6 +625,8 @@ CheckCenters(Relation index, VectorArray centers, IvfflatType type) /* Fine to sort in-place */ if (type == IVFFLAT_TYPE_VECTOR) qsort(centers->items, centers->length, centers->itemsize, CompareVectors); + else if (type == IVFFLAT_TYPE_HALFVEC) + qsort(centers->items, centers->length, centers->itemsize, CompareHalfVectors); else elog(ERROR, "Unsupported type"); diff --git a/src/ivfscan.c b/src/ivfscan.c index ce80f28..2b847bc 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -5,6 +5,7 @@ #include "access/relscan.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" +#include "halfvec.h" #include "lib/pairingheap.h" #include "ivfflat.h" #include "miscadmin.h" @@ -192,6 +193,8 @@ GetScanValue(IndexScanDesc scan) if (type == IVFFLAT_TYPE_VECTOR) value = PointerGetDatum(InitVector(so->dimensions)); + else if (type == IVFFLAT_TYPE_HALFVEC) + value = PointerGetDatum(InitHalfVector(so->dimensions)); else elog(ERROR, "Unsupported type"); } diff --git a/src/ivfutils.c b/src/ivfutils.c index b7efdca..f2c458c 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -2,6 +2,7 @@ #include "access/generic_xlog.h" #include "catalog/pg_type.h" +#include "halfvec.h" #include "ivfflat.h" #include "storage/bufmgr.h" #include "vector.h" @@ -77,6 +78,8 @@ IvfflatGetType(Relation index) type = (Form_pg_type) GETSTRUCT(tuple); if (strcmp(NameStr(type->typname), "vector") == 0) result = IVFFLAT_TYPE_VECTOR; + else if (strcmp(NameStr(type->typname), "halfvec") == 0) + result = IVFFLAT_TYPE_HALFVEC; else { ReleaseSysCache(tuple); @@ -113,6 +116,16 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty *value = PointerGetDatum(result); } + else if (type == IVFFLAT_TYPE_HALFVEC) + { + HalfVector *v = DatumGetHalfVector(*value); + HalfVector *result = InitHalfVector(v->dim); + + for (int i = 0; i < v->dim; i++) + result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm); + + *value = PointerGetDatum(result); + } else elog(ERROR, "Unsupported type"); diff --git a/test/expected/ivfflat_halfvec_cosine.out b/test/expected/ivfflat_halfvec_cosine.out new file mode 100644 index 0000000..6ed48cc --- /dev/null +++ b/test/expected/ivfflat_halfvec_cosine.out @@ -0,0 +1,26 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_cosine_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + count +------- + 3 +(1 row) + +DROP TABLE t; diff --git a/test/expected/ivfflat_halfvec_ip.out b/test/expected/ivfflat_halfvec_ip.out new file mode 100644 index 0000000..a2b7c65 --- /dev/null +++ b/test/expected/ivfflat_halfvec_ip.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_ip_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/ivfflat_halfvec_l2.out b/test/expected/ivfflat_halfvec_l2.out new file mode 100644 index 0000000..4a8e615 --- /dev/null +++ b/test/expected/ivfflat_halfvec_l2.out @@ -0,0 +1,36 @@ +SET enable_seqscan = off; +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_l2_ops) WITH (lists = 1); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec)) t2; + count +------- + 4 +(1 row) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +TRUNCATE t; +NOTICE: ivfflat index created with little data +DETAIL: This will cause low recall. +HINT: Drop the index until the table has more data. +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/sql/ivfflat_halfvec_cosine.sql b/test/sql/ivfflat_halfvec_cosine.sql new file mode 100644 index 0000000..9d9d87f --- /dev/null +++ b/test/sql/ivfflat_halfvec_cosine.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_cosine_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/ivfflat_halfvec_ip.sql b/test/sql/ivfflat_halfvec_ip.sql new file mode 100644 index 0000000..535e9bf --- /dev/null +++ b/test/sql/ivfflat_halfvec_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_ip_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::halfvec)) t2; + +DROP TABLE t; diff --git a/test/sql/ivfflat_halfvec_l2.sql b/test/sql/ivfflat_halfvec_l2.sql new file mode 100644 index 0000000..b2a600e --- /dev/null +++ b/test/sql/ivfflat_halfvec_l2.sql @@ -0,0 +1,16 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val halfvec(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING ivfflat (val halfvec_l2_ops) WITH (lists = 1); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <-> (SELECT NULL::halfvec)) t2; +SELECT COUNT(*) FROM t; + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/t/032_ivfflat_halfvec_build_recall.pl b/test/t/032_ivfflat_halfvec_build_recall.pl new file mode 100644 index 0000000..368f0ee --- /dev/null +++ b/test/t/032_ivfflat_halfvec_build_recall.pl @@ -0,0 +1,132 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my $dim = 10; +my $array_sql = join(",", ('random()') x $dim); + +sub test_recall +{ + my ($probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v halfvec($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("halfvec_l2_ops", "halfvec_ip_ops", "halfvec_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + + # Test approximate results + if ($operator ne "<#>") + { + # TODO Fix test (uniform random vectors all have similar inner product) + test_recall(1, 0.4, $operator); + test_recall(10, 0.95, $operator); + } + # Account for equal distances + test_recall(100, 0.9925, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + if ($operator ne "<#>") + { + # TODO Fix test (uniform random vectors all have similar inner product) + test_recall(1, 0.4, $operator); + test_recall(10, 0.95, $operator); + } + # Account for equal distances + test_recall(100, 0.9925, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();