Updated indexes to use l2_normalize functions

This commit is contained in:
Andrew Kane
2024-04-15 13:56:50 -07:00
parent c282627ce5
commit 10dacfd991
4 changed files with 12 additions and 37 deletions

View File

@@ -5,6 +5,7 @@
#include <float.h>
#include "fmgr.h"
#include "vector.h"
#if defined(__x86_64__) || defined(_M_AMD64)
@@ -43,5 +44,6 @@ typedef struct HalfVector
HalfVector *InitHalfVector(int dim);
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
#endif

View File

@@ -5,6 +5,7 @@
#include "access/generic_xlog.h"
#include "catalog/pg_type.h"
#include "catalog/pg_type_d.h"
#include "fmgr.h"
#include "halfutils.h"
#include "halfvec.h"
#include "hnsw.h"
@@ -206,27 +207,11 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
if (norm > 0)
{
/* TODO Remove vector-specific code */
/* TODO Remove type-specific code */
if (type == HNSW_TYPE_VECTOR)
{
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
}
*value = DirectFunctionCall1(l2_normalize, *value);
else if (type == HNSW_TYPE_HALFVEC)
{
HalfVector *v = DatumGetHalfVector(*value);
HalfVector *result = InitHalfVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
*value = PointerGetDatum(result);
}
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
else if (type == HNSW_TYPE_SPARSEVEC)
{
SparseVector *v = DatumGetSparseVector(*value);

View File

@@ -2,6 +2,7 @@
#include "access/generic_xlog.h"
#include "catalog/pg_type.h"
#include "fmgr.h"
#include "halfutils.h"
#include "halfvec.h"
#include "ivfflat.h"
@@ -108,25 +109,9 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty
if (norm > 0)
{
if (type == IVFFLAT_TYPE_VECTOR)
{
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
}
*value = DirectFunctionCall1(l2_normalize, *value);
else if (type == IVFFLAT_TYPE_HALFVEC)
{
HalfVector *v = DatumGetHalfVector(*value);
HalfVector *result = InitHalfVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
*value = PointerGetDatum(result);
}
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
else
elog(ERROR, "Unsupported type");

View File

@@ -1,6 +1,8 @@
#ifndef VECTOR_H
#define VECTOR_H
#include "fmgr.h"
#define VECTOR_MAX_DIM 16000
#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))
@@ -19,5 +21,6 @@ typedef struct Vector
Vector *InitVector(int dim);
void PrintVector(char *msg, Vector * vector);
int vector_cmp_internal(Vector * a, Vector * b);
Datum l2_normalize(PG_FUNCTION_ARGS);
#endif