mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Removed type-specific code from HNSW [skip ci]
This commit is contained in:
@@ -327,7 +327,7 @@ CREATE FUNCTION jaccard_distance(bit, bit) RETURNS float8
|
||||
CREATE FUNCTION bit_ivfflat_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION bit_hnsw_support(internal) RETURNS internal
|
||||
CREATE FUNCTION bit_hnsw_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- bit operators
|
||||
@@ -355,13 +355,13 @@ CREATE OPERATOR CLASS bit_hamming_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <~> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 hamming_distance(bit, bit),
|
||||
FUNCTION 4 bit_hnsw_support(internal);
|
||||
FUNCTION 4 bit_hnsw_max_dims(internal);
|
||||
|
||||
CREATE OPERATOR CLASS bit_jaccard_ops
|
||||
FOR TYPE bit USING hnsw AS
|
||||
OPERATOR 1 <%> (bit, bit) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 jaccard_distance(bit, bit),
|
||||
FUNCTION 4 bit_hnsw_support(internal);
|
||||
FUNCTION 4 bit_hnsw_max_dims(internal);
|
||||
|
||||
-- halfvec type
|
||||
|
||||
@@ -473,7 +473,7 @@ CREATE FUNCTION halfvec_avg(double precision[]) RETURNS halfvec
|
||||
CREATE FUNCTION halfvec_ivfflat_support(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION halfvec_hnsw_support(internal) RETURNS internal
|
||||
CREATE FUNCTION halfvec_hnsw_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- halfvec aggregates
|
||||
@@ -663,13 +663,13 @@ CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
OPERATOR 1 <-> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_l2_squared_distance(halfvec, halfvec),
|
||||
FUNCTION 4 halfvec_hnsw_support(internal);
|
||||
FUNCTION 4 halfvec_hnsw_max_dims(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 4 halfvec_hnsw_support(internal);
|
||||
FUNCTION 4 halfvec_hnsw_max_dims(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
@@ -677,13 +677,13 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 2 l2_norm(halfvec),
|
||||
FUNCTION 3 l2_normalize(halfvec),
|
||||
FUNCTION 4 halfvec_hnsw_support(internal);
|
||||
FUNCTION 4 halfvec_hnsw_max_dims(internal);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l1_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
OPERATOR 1 <+> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 l1_distance(halfvec, halfvec),
|
||||
FUNCTION 4 halfvec_hnsw_support(internal);
|
||||
FUNCTION 4 halfvec_hnsw_max_dims(internal);
|
||||
|
||||
--- sparsevec type
|
||||
|
||||
@@ -779,7 +779,10 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve
|
||||
CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION sparsevec_hnsw_support(internal) RETURNS internal
|
||||
CREATE FUNCTION sparsevec_hnsw_max_dims(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
CREATE FUNCTION sparsevec_hnsw_check_value(internal) RETURNS internal
|
||||
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
|
||||
|
||||
-- sparsevec casts
|
||||
@@ -872,13 +875,15 @@ CREATE OPERATOR CLASS sparsevec_l2_ops
|
||||
FOR TYPE sparsevec USING hnsw AS
|
||||
OPERATOR 1 <-> (sparsevec, sparsevec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 sparsevec_l2_squared_distance(sparsevec, sparsevec),
|
||||
FUNCTION 4 sparsevec_hnsw_support(internal);
|
||||
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
|
||||
FUNCTION 5 sparsevec_hnsw_check_value(internal);
|
||||
|
||||
CREATE OPERATOR CLASS sparsevec_ip_ops
|
||||
FOR TYPE sparsevec USING hnsw AS
|
||||
OPERATOR 1 <#> (sparsevec, sparsevec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
|
||||
FUNCTION 4 sparsevec_hnsw_support(internal);
|
||||
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
|
||||
FUNCTION 5 sparsevec_hnsw_check_value(internal);
|
||||
|
||||
CREATE OPERATOR CLASS sparsevec_cosine_ops
|
||||
FOR TYPE sparsevec USING hnsw AS
|
||||
@@ -886,10 +891,12 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops
|
||||
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
|
||||
FUNCTION 2 l2_norm(sparsevec),
|
||||
FUNCTION 3 l2_normalize(sparsevec),
|
||||
FUNCTION 4 sparsevec_hnsw_support(internal);
|
||||
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
|
||||
FUNCTION 5 sparsevec_hnsw_check_value(internal);
|
||||
|
||||
CREATE OPERATOR CLASS sparsevec_l1_ops
|
||||
FOR TYPE sparsevec USING hnsw AS
|
||||
OPERATOR 1 <+> (sparsevec, sparsevec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 l1_distance(sparsevec, sparsevec),
|
||||
FUNCTION 4 sparsevec_hnsw_support(internal);
|
||||
FUNCTION 4 sparsevec_hnsw_max_dims(internal),
|
||||
FUNCTION 5 sparsevec_hnsw_check_value(internal);
|
||||
|
||||
@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 4;
|
||||
amroutine->amsupport = 5;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
17
src/hnsw.h
17
src/hnsw.h
@@ -23,7 +23,8 @@
|
||||
#define HNSW_DISTANCE_PROC 1
|
||||
#define HNSW_NORM_PROC 2
|
||||
#define HNSW_NORMALIZE_PROC 3
|
||||
#define HNSW_TYPE_SUPPORT_PROC 4
|
||||
#define HNSW_MAX_DIMS_PROC 4
|
||||
#define HNSW_CHECK_VALUE_PROC 5
|
||||
|
||||
#define HNSW_VERSION 1
|
||||
#define HNSW_MAGIC_NUMBER 0xA953A953
|
||||
@@ -58,15 +59,6 @@
|
||||
#define HNSW_UPDATE_ENTRY_GREATER 1
|
||||
#define HNSW_UPDATE_ENTRY_ALWAYS 2
|
||||
|
||||
typedef enum HnswType
|
||||
{
|
||||
HNSW_TYPE_VECTOR,
|
||||
HNSW_TYPE_HALFVEC,
|
||||
HNSW_TYPE_BIT,
|
||||
HNSW_TYPE_SPARSEVEC,
|
||||
HNSW_TYPE_UNSUPPORTED
|
||||
} HnswType;
|
||||
|
||||
/* Build phases */
|
||||
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
|
||||
#define PROGRESS_HNSW_PHASE_LOAD 2
|
||||
@@ -254,7 +246,6 @@ typedef struct HnswBuildState
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
ForkNumber forkNum;
|
||||
HnswType type;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
@@ -269,6 +260,7 @@ typedef struct HnswBuildState
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
FmgrInfo *checkvalueprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
@@ -381,10 +373,9 @@ typedef struct HnswVacuumState
|
||||
int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
|
||||
HnswType HnswGetType(Relation index);
|
||||
Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
void HnswCheckValue(Datum value, HnswType type);
|
||||
void HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
void HnswInit(void);
|
||||
|
||||
@@ -488,7 +488,8 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
|
||||
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
|
||||
/* Check value */
|
||||
HnswCheckValue(value, buildstate->type);
|
||||
if (buildstate->checkvalueprocinfo != NULL)
|
||||
HnswCheckValue(buildstate->checkvalueprocinfo, buildstate->collation, value);
|
||||
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
@@ -675,18 +676,14 @@ HnswSharedMemoryAlloc(Size size, void *state)
|
||||
* Get max dimensions
|
||||
*/
|
||||
static int
|
||||
GetMaxDimensions(HnswType type)
|
||||
GetMaxDimensions(Relation index)
|
||||
{
|
||||
int maxDimensions = HNSW_MAX_DIM;
|
||||
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_MAX_DIMS_PROC);
|
||||
|
||||
if (type == HNSW_TYPE_HALFVEC)
|
||||
maxDimensions *= 2;
|
||||
else if (type == HNSW_TYPE_BIT)
|
||||
maxDimensions *= 32;
|
||||
else if (type == HNSW_TYPE_SPARSEVEC)
|
||||
maxDimensions = INT_MAX;
|
||||
if (procinfo == NULL)
|
||||
return HNSW_MAX_DIM;
|
||||
|
||||
return maxDimensions;
|
||||
return DatumGetInt32(FunctionCall1(procinfo, PointerGetDatum(NULL)));
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -701,13 +698,16 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
buildstate->forkNum = forkNum;
|
||||
buildstate->type = HnswGetType(index);
|
||||
|
||||
buildstate->m = HnswGetM(index);
|
||||
buildstate->efConstruction = HnswGetEfConstruction(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
maxDimensions = GetMaxDimensions(buildstate->type);
|
||||
/* Disallow varbit since require fixed dimensions */
|
||||
if (TupleDescAttr(index->rd_att, 0)->atttypid == VARBITOID)
|
||||
elog(ERROR, "type not supported for hnsw index");
|
||||
|
||||
maxDimensions = GetMaxDimensions(index);
|
||||
|
||||
/* Require column to have dimensions to be indexed */
|
||||
if (buildstate->dimensions < 0)
|
||||
@@ -726,6 +726,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
|
||||
buildstate->checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);
|
||||
|
||||
@@ -612,6 +612,7 @@ static void
|
||||
HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid)
|
||||
{
|
||||
Datum value;
|
||||
FmgrInfo *checkvalueprocinfo = HnswOptionalProcInfo(index, HNSW_CHECK_VALUE_PROC);
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
|
||||
@@ -619,7 +620,8 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
|
||||
/* Check value */
|
||||
HnswCheckValue(value, HnswGetType(index));
|
||||
if (checkvalueprocinfo != NULL)
|
||||
HnswCheckValue(checkvalueprocinfo, collation, value);
|
||||
|
||||
/* Normalize if needed */
|
||||
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
|
||||
@@ -152,27 +152,6 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
|
||||
return index_getprocinfo(index, 1, procnum);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get type
|
||||
*/
|
||||
HnswType
|
||||
HnswGetType(Relation index)
|
||||
{
|
||||
FmgrInfo *procinfo = HnswOptionalProcInfo(index, HNSW_TYPE_SUPPORT_PROC);
|
||||
Oid typid = TupleDescAttr(index->rd_att, 0)->atttypid;
|
||||
HnswType result;
|
||||
|
||||
if (procinfo == NULL)
|
||||
return HNSW_TYPE_VECTOR;
|
||||
|
||||
result = (HnswType) DatumGetInt32(FunctionCall1(procinfo, ObjectIdGetDatum(typid)));
|
||||
|
||||
if (result == HNSW_TYPE_UNSUPPORTED)
|
||||
elog(ERROR, "type not supported for hnsw index");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Normalize value
|
||||
*/
|
||||
@@ -198,15 +177,9 @@ HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
* Check if a value can be indexed
|
||||
*/
|
||||
void
|
||||
HnswCheckValue(Datum value, HnswType type)
|
||||
HnswCheckValue(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
{
|
||||
if (type == HNSW_TYPE_SPARSEVEC)
|
||||
{
|
||||
SparseVector *vec = DatumGetSparseVector(value);
|
||||
|
||||
if (vec->nnz > HNSW_MAX_NNZ)
|
||||
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
|
||||
}
|
||||
FunctionCall1Coll(procinfo, collation, value);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1303,28 +1276,35 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
|
||||
}
|
||||
}
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_support);
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_hnsw_max_dims);
|
||||
Datum
|
||||
halfvec_hnsw_support(PG_FUNCTION_ARGS)
|
||||
halfvec_hnsw_max_dims(PG_FUNCTION_ARGS)
|
||||
{
|
||||
PG_RETURN_INT32(HNSW_TYPE_HALFVEC);
|
||||
PG_RETURN_INT32(HNSW_MAX_DIM * 2);
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_support);
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(bit_hnsw_max_dims);
|
||||
Datum
|
||||
bit_hnsw_support(PG_FUNCTION_ARGS)
|
||||
bit_hnsw_max_dims(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Oid typid = PG_GETARG_OID(0);
|
||||
|
||||
if (typid == BITOID)
|
||||
PG_RETURN_INT32(HNSW_TYPE_BIT);
|
||||
else
|
||||
PG_RETURN_INT32(HNSW_TYPE_UNSUPPORTED);
|
||||
PG_RETURN_INT32(HNSW_MAX_DIM * 32);
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_support);
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_max_dims);
|
||||
Datum
|
||||
sparsevec_hnsw_support(PG_FUNCTION_ARGS)
|
||||
sparsevec_hnsw_max_dims(PG_FUNCTION_ARGS)
|
||||
{
|
||||
PG_RETURN_INT32(HNSW_TYPE_SPARSEVEC);
|
||||
PG_RETURN_INT32(SPARSEVEC_MAX_DIM);
|
||||
};
|
||||
|
||||
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_hnsw_check_value);
|
||||
Datum
|
||||
sparsevec_hnsw_check_value(PG_FUNCTION_ARGS)
|
||||
{
|
||||
SparseVector *vec = PG_GETARG_SPARSEVEC_P(0);
|
||||
|
||||
if (vec->nnz > HNSW_MAX_NNZ)
|
||||
elog(ERROR, "sparsevec cannot have more than %d non-zero elements for hnsw index", HNSW_MAX_NNZ);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user