Updated l2_normalize to remove zeros for sparsevec

This commit is contained in:
Andrew Kane
2024-04-15 14:42:53 -07:00
parent cadfc72b75
commit a415420a1c
2 changed files with 31 additions and 7 deletions

View File

@@ -873,19 +873,43 @@ sparsevec_l2_normalize(PG_FUNCTION_ARGS)
/* Return zero vector for zero norm */
if (norm > 0)
{
int zeros = 0;
for (int i = 0; i < a->nnz; i++)
{
result->indices[i] = a->indices[i];
/* TODO Remove zeros */
rx[i] = ax[i] / norm;
}
/* Check for overflow */
for (int i = 0; i < a->nnz; i++)
{
if (isinf(rx[i]))
float_overflow_error();
if (rx[i] == 0)
zeros++;
}
if (zeros > 0)
{
SparseVector *newResult = InitSparseVector(result->dim, result->nnz - zeros);
float *nx = SPARSEVEC_VALUES(newResult);
int j = 0;
for (int i = 0; i < result->nnz; i++)
{
if (rx[i] == 0)
continue;
newResult->indices[j] = result->indices[i];
nx[j] = rx[i];
j++;
/* Safety check */
if (j == newResult->nnz)
break;
}
pfree(result);
PG_RETURN_POINTER(newResult);
}
}