diff --git a/src/halfvec.h b/src/halfvec.h index 99a4a99..42b6dce 100644 --- a/src/halfvec.h +++ b/src/halfvec.h @@ -5,6 +5,7 @@ #include +#include "fmgr.h" #include "vector.h" #if defined(__x86_64__) || defined(_M_AMD64) @@ -43,5 +44,6 @@ typedef struct HalfVector HalfVector *InitHalfVector(int dim); int halfvec_cmp_internal(HalfVector * a, HalfVector * b); +Datum halfvec_l2_normalize(PG_FUNCTION_ARGS); #endif diff --git a/src/hnswutils.c b/src/hnswutils.c index 16f7357..5071dfa 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -5,6 +5,7 @@ #include "access/generic_xlog.h" #include "catalog/pg_type.h" #include "catalog/pg_type_d.h" +#include "fmgr.h" #include "halfutils.h" #include "halfvec.h" #include "hnsw.h" @@ -206,27 +207,11 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type) if (norm > 0) { - /* TODO Remove vector-specific code */ + /* TODO Remove type-specific code */ if (type == HNSW_TYPE_VECTOR) - { - Vector *v = DatumGetVector(*value); - Vector *result = InitVector(v->dim); - - for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; - - *value = PointerGetDatum(result); - } + *value = DirectFunctionCall1(l2_normalize, *value); else if (type == HNSW_TYPE_HALFVEC) - { - HalfVector *v = DatumGetHalfVector(*value); - HalfVector *result = InitHalfVector(v->dim); - - for (int i = 0; i < v->dim; i++) - result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm); - - *value = PointerGetDatum(result); - } + *value = DirectFunctionCall1(halfvec_l2_normalize, *value); else if (type == HNSW_TYPE_SPARSEVEC) { SparseVector *v = DatumGetSparseVector(*value); diff --git a/src/ivfutils.c b/src/ivfutils.c index e0c9d0a..6d41647 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -2,6 +2,7 @@ #include "access/generic_xlog.h" #include "catalog/pg_type.h" +#include "fmgr.h" #include "halfutils.h" #include "halfvec.h" #include "ivfflat.h" @@ -108,25 +109,9 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty if (norm > 0) { if (type == IVFFLAT_TYPE_VECTOR) - { - Vector *v = DatumGetVector(*value); - Vector *result = InitVector(v->dim); - - for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; - - *value = PointerGetDatum(result); - } + *value = DirectFunctionCall1(l2_normalize, *value); else if (type == IVFFLAT_TYPE_HALFVEC) - { - HalfVector *v = DatumGetHalfVector(*value); - HalfVector *result = InitHalfVector(v->dim); - - for (int i = 0; i < v->dim; i++) - result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm); - - *value = PointerGetDatum(result); - } + *value = DirectFunctionCall1(halfvec_l2_normalize, *value); else elog(ERROR, "Unsupported type"); diff --git a/src/vector.h b/src/vector.h index e649471..4742c37 100644 --- a/src/vector.h +++ b/src/vector.h @@ -1,6 +1,8 @@ #ifndef VECTOR_H #define VECTOR_H +#include "fmgr.h" + #define VECTOR_MAX_DIM 16000 #define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim)) @@ -19,5 +21,6 @@ typedef struct Vector Vector *InitVector(int dim); void PrintVector(char *msg, Vector * vector); int vector_cmp_internal(Vector * a, Vector * b); +Datum l2_normalize(PG_FUNCTION_ARGS); #endif