diff --git a/src/sparsevec.c b/src/sparsevec.c index 270b049..2d41214 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -943,10 +943,10 @@ sparsevec_cmp_internal(SparseVector * a, SparseVector * b) return 1; } - if (a->nnz < b->nnz) + if (a->nnz < b->nnz && b->indices[nnz] <= a->dim) return bx[nnz] < 0 ? 1 : -1; - if (a->nnz > b->nnz) + if (a->nnz > b->nnz && a->indices[nnz] <= b->dim) return ax[nnz] < 0 ? -1 : 1; if (a->dim < b->dim) diff --git a/test/t/033_comparison.pl b/test/t/033_comparison.pl new file mode 100644 index 0000000..a30043b --- /dev/null +++ b/test/t/033_comparison.pl @@ -0,0 +1,48 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my $array_sql = join(",", ('floor(random() * 2)::int - 1') x 3); + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v real[]);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); + +for (1 .. 50) +{ + # Generate queries + my @r = (); + for (1 .. (int(rand() * 3) + 2)) + { + push(@r, int(rand() * 2) - 1); + } + my $query = "{" . join(",", @r) . "}"; + + # Get expected results + my $expected = $node->safe_psql("postgres", "SELECT btarraycmp(v, '$query') FROM tst"); + + # Test vector + my $actual = $node->safe_psql("postgres", "SELECT vector_cmp(v::vector, '$query'::real[]::vector) FROM tst"); + is($expected, $actual); + + # Test halfvec + $actual = $node->safe_psql("postgres", "SELECT halfvec_cmp(v::halfvec, '$query'::real[]::halfvec) FROM tst"); + is($expected, $actual); + + # Test sparsevec + $actual = $node->safe_psql("postgres", "SELECT sparsevec_cmp(v::vector::sparsevec, '$query'::real[]::vector::sparsevec) FROM tst"); + is($expected, $actual); +} + +done_testing();