From 116501f062de5fe5839c0bd4dc3dcb181e281f11 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Thu, 10 Oct 2024 17:07:47 -0700 Subject: [PATCH] Added support for inline filtering with IVFFlat --- CHANGELOG.md | 1 + README.md | 6 + sql/vector--0.7.4--0.8.0.sql | 2 + sql/vector.sql | 10 ++ src/ivfbuild.c | 13 ++ src/ivfflat.c | 2 +- src/ivfscan.c | 32 ++++ test/t/041_ivfflat_inline_filtering.pl | 197 +++++++++++++++++++++++++ 8 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 test/t/041_ivfflat_inline_filtering.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index a7d9924..e68efab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.8.0 (unreleased) +- Added support for inline filtering with IVFFlat - Added casts for arrays to `sparsevec` - Improved cost estimation - Improved performance of HNSW inserts and on-disk index builds diff --git a/README.md b/README.md index 6cdca4f..9781c1b 100644 --- a/README.md +++ b/README.md @@ -439,6 +439,12 @@ Create an index on one [or more](https://www.postgresql.org/docs/current/indexes CREATE INDEX ON items (category_id); ``` +Or a composite IVFFlat index for approximate search (added in 0.8.0) + +```sql +CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops, category_id) WITH (lists = 100); +``` + Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search ```sql diff --git a/sql/vector--0.7.4--0.8.0.sql b/sql/vector--0.7.4--0.8.0.sql index e00348d..63acffe 100644 --- a/sql/vector--0.7.4--0.8.0.sql +++ b/sql/vector--0.7.4--0.8.0.sql @@ -24,3 +24,5 @@ CREATE CAST (double precision[] AS sparsevec) CREATE CAST (numeric[] AS sparsevec) WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT; + +-- TODO add ivfflat attributes diff --git a/sql/vector.sql b/sql/vector.sql index 7fc3671..4ddd572 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -916,3 +916,13 @@ CREATE OPERATOR CLASS sparsevec_l1_ops OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops, FUNCTION 1 l1_distance(sparsevec, sparsevec), FUNCTION 3 hnsw_sparsevec_support(internal); + +-- ivfflat attributes + +CREATE OPERATOR CLASS vector_integer_ops + DEFAULT FOR TYPE integer USING ivfflat AS + OPERATOR 2 < , + OPERATOR 3 <= , + OPERATOR 4 = , + OPERATOR 5 >= , + OPERATOR 6 > ; diff --git a/src/ivfbuild.c b/src/ivfbuild.c index e88a0bb..f67cfef 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -341,6 +341,19 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), errmsg("type not supported for ivfflat index"))); + /* TODO See if needed */ + if (IndexRelationGetNumberOfKeyAttributes(index) > 3) + elog(ERROR, "index cannot have more than three columns"); + + if (!OidIsValid(index_getprocid(index, 1, IVFFLAT_DISTANCE_PROC))) + elog(ERROR, "first column must be a vector"); + + for (int i = 1; i < IndexRelationGetNumberOfKeyAttributes(index); i++) + { + if (OidIsValid(index_getprocid(index, i + 1, IVFFLAT_DISTANCE_PROC))) + elog(ERROR, "column %d cannot be a vector", i + 1); + } + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) ereport(ERROR, diff --git a/src/ivfflat.c b/src/ivfflat.c index 395040d..31501b5 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -167,7 +167,7 @@ ivfflathandler(PG_FUNCTION_ARGS) amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; - amroutine->amcanmulticol = false; + amroutine->amcanmulticol = true; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; diff --git a/src/ivfscan.c b/src/ivfscan.c index 74e3675..d1ff6e2 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -104,6 +104,34 @@ GetScanLists(IndexScanDesc scan, Datum value) } } +/* + * Check if matches scan keys + */ +static bool +MatchesScanKeys(IndexScanDesc scan, IndexTuple itup, TupleDesc tupdesc) +{ + for (int i = 0; i < scan->numberOfKeys; i++) + { + ScanKey key = &scan->keyData[i]; + bool attnull = key->sk_flags & SK_ISNULL; + bool isnull; + Datum value = index_getattr(itup, key->sk_attno, tupdesc, &isnull); + + if (isnull || attnull) + { + if (isnull != attnull) + return false; + } + else + { + if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) + return false; + } + } + + return true; +} + /* * Get items */ @@ -140,6 +168,10 @@ GetScanItems(IndexScanDesc scan, Datum value) ItemId itemid = PageGetItemId(page, offno); itup = (IndexTuple) PageGetItem(page, itemid); + + if (!MatchesScanKeys(scan, itup, tupdesc)) + continue; + datum = index_getattr(itup, 1, tupdesc, &isnull); /* diff --git a/test/t/041_ivfflat_inline_filtering.pl b/test/t/041_ivfflat_inline_filtering.pl new file mode 100644 index 0000000..0b68d3a --- /dev/null +++ b/test/t/041_ivfflat_inline_filtering.pl @@ -0,0 +1,197 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @where = (); +my @expected; +my $limit = 20; +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); +my $nc = 100; +my $nc2 = 10; + +sub test_recall +{ + my ($probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + for my $j (0 .. 2) + { + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + EXPLAIN ANALYZE SELECT i FROM tst WHERE $where[$j] ORDER BY v $operator '$queries[$j]' LIMIT $limit; + )); + like($explain, qr/Index Cond/); + } + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SELECT i FROM tst WHERE $where[$i] ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + is(scalar(@actual_ids), $limit); + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4, c2 int4);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc, i % $nc2 FROM generate_series(1, 50000) i;" +); + +# Generate queries +for my $i (1 .. 100) +{ + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + push(@queries, "[" . join(",", @r) . "]"); + + if ($i % 3 == 0) + { + my $c = int(rand() * $nc); + push(@where, "c = $c"); + } + elsif ($i % 3 == 1) + { + my $c2 = int(rand() * $nc2); + push(@where, "c2 = $c2"); + } + else + { + # use c2 to ensure results + my $c2 = int(rand() * $nc2); + push(@where, "c = $c2 AND c2 = $c2"); + } +} + +# Add index +$node->safe_psql("postgres", qq( + CREATE INDEX ON tst USING ivfflat (v vector_l2_ops, c, c2) WITH (lists = 100); +)); + +# Insert more rows +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc, i % $nc2 FROM generate_series(1, 50000) i;" +); + +# Get exact results +@expected = (); +for my $i (0 .. $#queries) +{ + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst WHERE $where[$i] ORDER BY v <-> '$queries[$i]' LIMIT $limit; + )); + push(@expected, $res); +} + +# Test recall +test_recall(10, 0.99, '<->'); + +# Test vacuum +$node->safe_psql("postgres", "DELETE FROM tst WHERE c > 5;"); +$node->safe_psql("postgres", "VACUUM tst;"); + +# Test less than +my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c < 10 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(c < 10\)/); + +# Test less than or equal +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c <= 10 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(c <= 10\)/); + +# Test greater than or equal +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c >= 90 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(c >= 90\)/); + +# Test greater than +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c > 90 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(c > 90\)/); + +# Test multiple attribute columns +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = 1 AND c2 = 1 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(\(c = 1\) AND \(c2 = 1\)\)/); + +# Test only last attribute column +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c2 = 1 ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Cond: \(c2 = 1\)/); + +# Test only vector column +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v <-> '$queries[0]' LIMIT $limit; +)); +like($explain, qr/Index Scan/); + +# Test only attribute columns +$explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst WHERE c = 1; +)); +like($explain, qr/Seq Scan/); + +# Test columns +my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (c);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (c, v vector_cosine_ops);"); +like($stderr, qr/first column must be a vector/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_cosine_ops, c, c, c);"); +like($stderr, qr/index cannot have more than three columns/); + +($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_cosine_ops, v vector_cosine_ops);"); +like($stderr, qr/column 2 cannot be a vector/); + +done_testing();