Removed type-specific code from HNSW [skip ci]

This commit is contained in:
Andrew Kane
2024-04-24 14:53:45 -07:00
parent b8bdf317f0
commit 3eef1ff5c2
6 changed files with 64 additions and 83 deletions

View File

@@ -327,7 +327,7 @@ CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal
CREATE FUNCTION bit_hnsw_max_dims(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- bit operators
@@ -355,13 +355,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 hamming_distance(bit, bit),
FUNCTION 4 bit_hnsw_support(internal);
FUNCTION 4 bit_hnsw_max_dims(internal);
CREATE OPERATOR CLASS bit_jaccard_ops
FOR TYPE bit USING hnsw AS
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
FUNCTION 1 jaccard_distance(bit, bit),
FUNCTION 4 bit_hnsw_support(internal);
FUNCTION 4 bit_hnsw_max_dims(internal);
-- halfvec type
@@ -473,7 +473,7 @@ CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal
CREATE FUNCTION halfvec_hnsw_max_dims(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- halfvec aggregates
@@ -663,13 +663,13 @@ CREATE OPERATOR CLASS halfvec_l2_ops
FOR TYPE halfvec USING hnsw AS
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
FUNCTION 4 halfvec_hnsw_support(internal);
FUNCTION 4 halfvec_hnsw_max_dims(internal);
CREATE OPERATOR CLASS halfvec_ip_ops
FOR TYPE halfvec USING hnsw AS
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 4 halfvec_hnsw_support(internal);
FUNCTION 4 halfvec_hnsw_max_dims(internal);
CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING hnsw AS
@@ -677,13 +677,13 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 2 l2_norm(halfvec),
FUNCTION 3 l2_normalize(halfvec),
FUNCTION 4 halfvec_hnsw_support(internal);
FUNCTION 4 halfvec_hnsw_max_dims(internal);
CREATE OPERATOR CLASS halfvec_l1_ops
FOR TYPE halfvec USING hnsw AS
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 l1_distance(halfvec, halfvec),
FUNCTION 4 halfvec_hnsw_support(internal);
FUNCTION 4 halfvec_hnsw_max_dims(internal);
--- sparsevec type
@@ -779,7 +779,10 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve
CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal
CREATE FUNCTION sparsevec_hnsw_max_dims(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION sparsevec_hnsw_check_value(internal) RETURNS internal
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
-- sparsevec casts
@@ -872,13 +875,15 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
FUNCTION 4 sparsevec_hnsw_support(internal);
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
FUNCTION 5 sparsevec_hnsw_check_value(internal);
CREATE OPERATOR CLASS sparsevec_ip_ops
FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
FUNCTION 4 sparsevec_hnsw_support(internal);
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
FUNCTION 5 sparsevec_hnsw_check_value(internal);
CREATE OPERATOR CLASS sparsevec_cosine_ops
FOR TYPE sparsevec USING hnsw AS
@@ -886,10 +891,12 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
FUNCTION 2 l2_norm(sparsevec),
FUNCTION 3 l2_normalize(sparsevec),
FUNCTION 4 sparsevec_hnsw_support(internal);
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
FUNCTION 5 sparsevec_hnsw_check_value(internal);
CREATE OPERATOR CLASS sparsevec_l1_ops
FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 l1_distance(sparsevec, sparsevec),
FUNCTION 4 sparsevec_hnsw_support(internal);
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
FUNCTION 5 sparsevec_hnsw_check_value(internal);

View File

@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 4;
amroutine->amsupport = 5;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif

View File

@@ -23,7 +23,8 @@
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
#define HNSW_NORMALIZE_PROC 3
#define HNSW_TYPE_SUPPORT_PROC 4
#define HNSW_MAX_DIMS_PROC 4
#define HNSW_CHECK_VALUE_PROC 5
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -58,15 +59,6 @@
#define HNSW_UPDATE_ENTRY_GREATER 1
#define HNSW_UPDATE_ENTRY_ALWAYS 2
typedef enum HnswType
{
HNSW_TYPE_VECTOR,
HNSW_TYPE_HALFVEC,
HNSW_TYPE_BIT,
HNSW_TYPE_SPARSEVEC,
HNSW_TYPE_UNSUPPORTED
} HnswType;
/* Build phases */
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
#define PROGRESS_HNSW_PHASE_LOAD 2
@@ -254,7 +246,6 @@ typedef struct HnswBuildState
Relation index;
IndexInfo *indexInfo;
ForkNumber forkNum;
HnswType type;
/* Settings */
int dimensions;
@@ -269,6 +260,7 @@ typedef struct HnswBuildState
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
FmgrInfo *checkvalueprocinfo;
Oid collation;
/* Variables */
@@ -381,10 +373,9 @@ typedef struct HnswVacuumState
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
HnswType HnswGetType(Relation index);
Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
void HnswCheckValue(Datum value, HnswType type);
void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInit(void);

View File

@@ -488,7 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Check value */
HnswCheckValue(value, buildstate->type);
if (buildstate->checkvalueprocinfo != NULL)
HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value);
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
@@ -675,18 +676,14 @@ HnswSharedMemoryAlloc(Size size, void *state)
* Get max dimensions
*/
static int
GetMaxDimensions(HnswType type)
GetMaxDimensions(Relation index)
{
int maxDimensions = HNSW_MAX_DIM;
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC);
if (type == HNSW_TYPE_HALFVEC)
maxDimensions *= 2;
else if (type == HNSW_TYPE_BIT)
maxDimensions *= 32;
else if (type == HNSW_TYPE_SPARSEVEC)
maxDimensions = INT_MAX;
if (procinfo == NULL)
return HNSW_MAX_DIM;
return maxDimensions;
return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL)));
}
/*
@@ -701,13 +698,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->forkNum = forkNum;
buildstate->type = HnswGetType(index);
buildstate->m = HnswGetM(index);
buildstate->efConstruction = HnswGetEfConstruction(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
maxDimensions = GetMaxDimensions(buildstate->type);
/* Disallow varbit since require fixed dimensions */
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
elog(ERROR, "type not supported for hnsw index");
maxDimensions = GetMaxDimensions(index);
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
@@ -726,6 +726,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
buildstate->collation = index->rd_indcollation[0];
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);

View File

@@ -612,6 +612,7 @@ static void
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid)
{
Datum value;
FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
FmgrInfo *normprocinfo;
Oid collation = index->rd_indcollation[0];
@@ -619,7 +620,8 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Check value */
HnswCheckValue(value, HnswGetType(index));
if (checkvalueprocinfo != NULL)
HnswCheckValue(checkvalueprocinfo, collation, value);
/* Normalize if needed */
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);

View File

@@ -152,27 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
return index_getprocinfo(index, 1, procnum);
}
/*
* Get type
*/
HnswType
HnswGetType(Relation index)
{
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_SUPPORT_PROC);
Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid;
HnswType result;
if (procinfo == NULL)
return HNSW_TYPE_VECTOR;
result = (HnswType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
if (result == HNSW_TYPE_UNSUPPORTED)
elog(ERROR, "type not supported for hnsw index");
return result;
}
/*
* Normalize value
*/
@@ -198,15 +177,9 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
* Check if a value can be indexed
*/
void
HnswCheckValue(Datum value, HnswType type)
HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value)
{
if (type == HNSW_TYPE_SPARSEVEC)
{
SparseVector *vec = DatumGetSparseVector(value);
if (vec->nnz > HNSW_MAX_NNZ)
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
}
FunctionCall1Coll(procinfo, collation, value);
}
/*
@@ -1303,28 +1276,35 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
}
}
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_support);
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_max_dims);
Datum
halfvec_hnsw_support(PG_FUNCTION_ARGS)
halfvec_hnsw_max_dims(PG_FUNCTION_ARGS)
{
PG_RETURN_INT32(HNSW_TYPE_HALFVEC);
PG_RETURN_INT32(HNSW_MAX_DIM * 2);
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_support);
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_max_dims);
Datum
bit_hnsw_support(PG_FUNCTION_ARGS)
bit_hnsw_max_dims(PG_FUNCTION_ARGS)
{
Oid typid = PG_GETARG_OID(0);
if (typid == BITOID)
PG_RETURN_INT32(HNSW_TYPE_BIT);
else
PG_RETURN_INT32(HNSW_TYPE_UNSUPPORTED);
PG_RETURN_INT32(HNSW_MAX_DIM * 32);
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_support);
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_max_dims);
Datum
sparsevec_hnsw_support(PG_FUNCTION_ARGS)
sparsevec_hnsw_max_dims(PG_FUNCTION_ARGS)
{
PG_RETURN_INT32(HNSW_TYPE_SPARSEVEC);
PG_RETURN_INT32(SPARSEVEC_MAX_DIM);
};
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_check_value);
Datum
sparsevec_hnsw_check_value(PG_FUNCTION_ARGS)
{
SparseVector *vec = PG_GETARG_SPARSEVEC_P(0);
if (vec->nnz > HNSW_MAX_NNZ)
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
PG_RETURN_VOID();
}