From a415420a1c8c3baa6ee1bbd693df30a2e0b018b0 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 15 Apr 2024 14:42:53 -0700 Subject: [PATCH] Updated l2_normalize to remove zeros for sparsevec --- src/sparsevec.c | 36 ++++++++++++++++++++++----- test/expected/sparsevec_functions.out | 2 +- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/sparsevec.c b/src/sparsevec.c index 5bed3ff..f923ad5 100644 --- a/src/sparsevec.c +++ b/src/sparsevec.c @@ -873,19 +873,43 @@ sparsevec_l2_normalize(PG_FUNCTION_ARGS) /* Return zero vector for zero norm */ if (norm > 0) { + int zeros = 0; + for (int i = 0; i < a->nnz; i++) { result->indices[i] = a->indices[i]; - - /* TODO Remove zeros */ rx[i] = ax[i] / norm; - } - /* Check for overflow */ - for (int i = 0; i < a->nnz; i++) - { if (isinf(rx[i])) float_overflow_error(); + + if (rx[i] == 0) + zeros++; + } + + if (zeros > 0) + { + SparseVector *newResult = InitSparseVector(result->dim, result->nnz - zeros); + float *nx = SPARSEVEC_VALUES(newResult); + int j = 0; + + for (int i = 0; i < result->nnz; i++) + { + if (rx[i] == 0) + continue; + + newResult->indices[j] = result->indices[i]; + nx[j] = rx[i]; + j++; + + /* Safety check */ + if (j == newResult->nnz) + break; + } + + pfree(result); + + PG_RETURN_POINTER(newResult); } } diff --git a/test/expected/sparsevec_functions.out b/test/expected/sparsevec_functions.out index e6af3c9..099bdfd 100644 --- a/test/expected/sparsevec_functions.out +++ b/test/expected/sparsevec_functions.out @@ -325,6 +325,6 @@ SELECT l2_normalize('{1:3e38}/1'::sparsevec); SELECT l2_normalize('{1:3e38,2:1e-37}/2'::sparsevec); l2_normalize -------------- - {1:1,2:0}/2 + {1:1}/2 (1 row)