Added basic support for float4 arrays

This commit is contained in:
Andrew Kane
2023-09-13 13:41:06 -07:00
parent 310a880186
commit 72e9cf06c1
9 changed files with 200 additions and 3 deletions

58
src/float4.c Normal file
View File

@@ -0,0 +1,58 @@
#include "postgres.h"

#include <math.h>

#include "utils/array.h"
/*
 * Validate a pair of array arguments before their raw data is read.
 *
 * Raises an error unless both arrays are one-dimensional, contain no nulls,
 * and have the same number of elements. Returns that common element count.
 *
 * The element type is not checked here.
 * NOTE(review): this assumes the SQL-level declarations restrict the
 * arguments to real[] - confirm against the extension script.
 */
static int
CheckFloat4Arrays(ArrayType *a, ArrayType *b)
{
	int			dim;

	/*
	 * An empty (0-D) array has no dimension entry to read, and a
	 * multi-dimensional array holds more elements than its first dimension.
	 */
	if (ARR_NDIM(a) != 1 || ARR_NDIM(b) != 1)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("array must be 1-D")));

	/*
	 * Null elements are not stored in the data area, so indexing it by
	 * position would run past the end of the stored values.
	 */
	if (array_contains_nulls(a) || array_contains_nulls(b))
		ereport(ERROR,
				(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
				 errmsg("array must not contain nulls")));

	dim = ARR_DIMS(a)[0];

	/* Both inputs are read for "dim" elements, so the lengths must match */
	if (dim != ARR_DIMS(b)[0])
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("different array dimensions %d and %d",
						dim, ARR_DIMS(b)[0])));

	return dim;
}

/*
 * Sum of squared element-wise differences over the first "dim" elements,
 * accumulated in single precision.
 */
static float
Float4SquaredDistance(const float *ax, const float *bx, int dim)
{
	float		distance = 0.0;

	/* Auto-vectorized */
	for (int i = 0; i < dim; i++)
	{
		float		diff = ax[i] - bx[i];

		distance += diff * diff;
	}

	return distance;
}

/*
 * Get the L2 distance between two float4 arrays
 *
 * Errors out if either array is not one-dimensional, contains nulls, or if
 * the arrays differ in length. Returns the distance as float8.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_distance);
Datum
float4_l2_distance(PG_FUNCTION_ARGS)
{
	ArrayType  *a = PG_GETARG_ARRAYTYPE_P(0);
	ArrayType  *b = PG_GETARG_ARRAYTYPE_P(1);
	int			dim = CheckFloat4Arrays(a, b);
	float	   *ax = (float *) ARR_DATA_PTR(a);
	float	   *bx = (float *) ARR_DATA_PTR(b);

	PG_RETURN_FLOAT8(sqrt((double) Float4SquaredDistance(ax, bx, dim)));
}

/*
 * Get the L2 squared distance between two float4 arrays
 * This saves a sqrt calculation
 *
 * Applies the same input checks as float4_l2_distance.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(float4_l2_squared_distance);
Datum
float4_l2_squared_distance(PG_FUNCTION_ARGS)
{
	ArrayType  *a = PG_GETARG_ARRAYTYPE_P(0);
	ArrayType  *b = PG_GETARG_ARRAYTYPE_P(1);
	int			dim = CheckFloat4Arrays(a, b);
	float	   *ax = (float *) ARR_DATA_PTR(a);
	float	   *bx = (float *) ARR_DATA_PTR(b);

	PG_RETURN_FLOAT8((double) Float4SquaredDistance(ax, bx, dim));
}

View File

@@ -33,6 +33,12 @@ HnswInit(void)
HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
add_int_reloption(hnsw_relopt_kind, "dimensions", "Number of dimensions",
HNSW_DEFAULT_DIMENSIONS, HNSW_MIN_DIMENSIONS, HNSW_MAX_DIMENSIONS
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
@@ -125,6 +131,7 @@ hnswoptions(Datum reloptions, bool validate)
static const relopt_parse_elt tab[] = {
{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
{"dimensions", RELOPT_TYPE_INT, offsetof(HnswOptions, dimensions)},
};
#if PG_VERSION_NUM >= 130000

View File

@@ -42,6 +42,9 @@
#define HNSW_DEFAULT_EF_SEARCH 40
#define HNSW_MIN_EF_SEARCH 1
#define HNSW_MAX_EF_SEARCH 1000
/*
 * Bounds for the "dimensions" index option. The default of -1 lies below the
 * minimum and stands for "not set": the build then falls back to the type
 * modifier of the indexed column.
 */
#define HNSW_DEFAULT_DIMENSIONS -1
#define HNSW_MIN_DIMENSIONS 1
#define HNSW_MAX_DIMENSIONS HNSW_MAX_DIM
/* Tuple types */
#define HNSW_ELEMENT_TUPLE_TYPE 1
@@ -131,6 +134,7 @@ typedef struct HnswOptions
int32 vl_len_; /* varlena header (do not touch directly!) */
int m; /* number of connections */
int efConstruction; /* size of dynamic candidate list */
int dimensions;
} HnswOptions;
typedef struct HnswBuildState
@@ -259,6 +263,7 @@ typedef struct HnswVacuumState
/* Methods */
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
int HnswGetDimensions(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);

View File

@@ -400,6 +400,7 @@ HnswGetMaxInMemoryElements(int m, double ml, int dimensions)
elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1);
elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2));
elementSize += sizeof(ItemPointerData);
/* TODO Handle non-vector types */
elementSize += VECTOR_SIZE(dimensions);
return (maintenance_work_mem * 1024L) / elementSize;
}
@@ -417,7 +418,10 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->m = HnswGetM(index);
buildstate->efConstruction = HnswGetEfConstruction(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
buildstate->dimensions = HnswGetDimensions(index);
if (buildstate->dimensions < 0)
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)

View File

@@ -35,6 +35,20 @@ HnswGetEfConstruction(Relation index)
return HNSW_DEFAULT_EF_CONSTRUCTION;
}
/*
 * Return the "dimensions" option stored with the index, or
 * HNSW_DEFAULT_DIMENSIONS when the index carries no options at all.
 */
int
HnswGetDimensions(Relation index)
{
	HnswOptions *options = (HnswOptions *) index->rd_options;

	/* No reloptions recorded for this index: report the default */
	if (options == NULL)
		return HNSW_DEFAULT_DIMENSIONS;

	return options->dimensions;
}
/*
* Get proc
*/