mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Updated indexes to use l2_normalize functions
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "fmgr.h"
|
||||
#include "vector.h"
|
||||
|
||||
#if defined(__x86_64__) || defined(_M_AMD64)
|
||||
@@ -43,5 +44,6 @@ typedef struct HalfVector
|
||||
|
||||
HalfVector *InitHalfVector(int dim);
|
||||
int halfvec_cmp_internal(HalfVector * a, HalfVector * b);
|
||||
Datum halfvec_l2_normalize(PG_FUNCTION_ARGS);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "access/generic_xlog.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#include "fmgr.h"
|
||||
#include "halfutils.h"
|
||||
#include "halfvec.h"
|
||||
#include "hnsw.h"
|
||||
@@ -206,27 +207,11 @@ HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
/* TODO Remove vector-specific code */
|
||||
/* TODO Remove type-specific code */
|
||||
if (type == HNSW_TYPE_VECTOR)
|
||||
{
|
||||
Vector *v = DatumGetVector(*value);
|
||||
Vector *result = InitVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
*value = DirectFunctionCall1(l2_normalize, *value);
|
||||
else if (type == HNSW_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *v = DatumGetHalfVector(*value);
|
||||
HalfVector *result = InitHalfVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
||||
else if (type == HNSW_TYPE_SPARSEVEC)
|
||||
{
|
||||
SparseVector *v = DatumGetSparseVector(*value);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "access/generic_xlog.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "fmgr.h"
|
||||
#include "halfutils.h"
|
||||
#include "halfvec.h"
|
||||
#include "ivfflat.h"
|
||||
@@ -108,25 +109,9 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType ty
|
||||
if (norm > 0)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
{
|
||||
Vector *v = DatumGetVector(*value);
|
||||
Vector *result = InitVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
*value = DirectFunctionCall1(l2_normalize, *value);
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
{
|
||||
HalfVector *v = DatumGetHalfVector(*value);
|
||||
HalfVector *result = InitHalfVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = Float4ToHalfUnchecked(HalfToFloat4(v->x[i]) / norm);
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef VECTOR_H
|
||||
#define VECTOR_H
|
||||
|
||||
#include "fmgr.h"
|
||||
|
||||
#define VECTOR_MAX_DIM 16000
|
||||
|
||||
#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))
|
||||
@@ -19,5 +21,6 @@ typedef struct Vector
|
||||
Vector *InitVector(int dim);
|
||||
void PrintVector(char *msg, Vector * vector);
|
||||
int vector_cmp_internal(Vector * a, Vector * b);
|
||||
Datum l2_normalize(PG_FUNCTION_ARGS);
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user