mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Mark type-specific code
This commit is contained in:
@@ -55,6 +55,11 @@
|
||||
#define HNSW_UPDATE_ENTRY_GREATER 1
|
||||
#define HNSW_UPDATE_ENTRY_ALWAYS 2
|
||||
|
||||
typedef enum HnswType
|
||||
{
|
||||
HNSW_TYPE_VECTOR
|
||||
} HnswType;
|
||||
|
||||
/* Build phases */
|
||||
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
|
||||
#define PROGRESS_HNSW_PHASE_LOAD 2
|
||||
@@ -242,6 +247,7 @@ typedef struct HnswBuildState
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
ForkNumber forkNum;
|
||||
HnswType type;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
@@ -366,7 +372,8 @@ typedef struct HnswVacuumState
|
||||
int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value);
|
||||
HnswType HnswGetType(Relation index);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
void HnswInit(void);
|
||||
|
||||
@@ -489,7 +489,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value))
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -677,6 +677,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
buildstate->forkNum = forkNum;
|
||||
buildstate->type = HnswGetType(index);
|
||||
|
||||
buildstate->m = HnswGetM(index);
|
||||
buildstate->efConstruction = HnswGetEfConstruction(index);
|
||||
|
||||
@@ -622,7 +622,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(normprocinfo, collation, &value))
|
||||
if (!HnswNormValue(normprocinfo, collation, &value, HnswGetType(index)))
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
|
||||
|
||||
/* Fine if normalization fails */
|
||||
if (so->normprocinfo != NULL)
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value);
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation));
|
||||
}
|
||||
|
||||
return value;
|
||||
|
||||
@@ -149,6 +149,15 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
|
||||
return index_getprocinfo(index, 1, procnum);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get vector type
|
||||
*/
|
||||
HnswType
|
||||
HnswGetType(Relation index)
|
||||
{
|
||||
return HNSW_TYPE_VECTOR;
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
@@ -158,20 +167,25 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
|
||||
* if it's different than the original value
|
||||
*/
|
||||
bool
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value)
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
/* TODO Remove vector-specific code */
|
||||
Vector *v = DatumGetVector(*value);
|
||||
Vector *result = InitVector(v->dim);
|
||||
if (type == HNSW_TYPE_VECTOR)
|
||||
{
|
||||
Vector *v = DatumGetVector(*value);
|
||||
Vector *result = InitVector(v->dim);
|
||||
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
for (int i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
*value = PointerGetDatum(result);
|
||||
}
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user