Mark type-specific code

This commit is contained in:
Andrew Kane
2024-03-29 14:01:48 -07:00
parent 7d63bb4b98
commit 2c48e3edc2
5 changed files with 32 additions and 10 deletions

View File

@@ -55,6 +55,11 @@
#define HNSW_UPDATE_ENTRY_GREATER 1
#define HNSW_UPDATE_ENTRY_ALWAYS 2
typedef enum HnswType
{
HNSW_TYPE_VECTOR
} HnswType;
/* Build phases */
/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */
#define PROGRESS_HNSW_PHASE_LOAD 2
@@ -242,6 +247,7 @@ typedef struct HnswBuildState
Relation index;
IndexInfo *indexInfo;
ForkNumber forkNum;
HnswType type;
/* Settings */
int dimensions;
@@ -366,7 +372,8 @@ typedef struct HnswVacuumState
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value);
HnswType HnswGetType(Relation index);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInit(void);

View File

@@ -489,7 +489,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value))
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
return false;
}
@@ -677,6 +677,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->forkNum = forkNum;
buildstate->type = HnswGetType(index);
buildstate->m = HnswGetM(index);
buildstate->efConstruction = HnswGetEfConstruction(index);

View File

@@ -622,7 +622,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
if (normprocinfo != NULL)
{
if (!HnswNormValue(normprocinfo, collation, &value))
if (!HnswNormValue(normprocinfo, collation, &value, HnswGetType(index)))
return;
}

View File

@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
HnswNormValue(so->normprocinfo, so->collation, &value);
HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation));
}
return value;

View File

@@ -149,6 +149,15 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
return index_getprocinfo(index, 1, procnum);
}
/*
* Get vector type
*/
HnswType
HnswGetType(Relation index)
{
return HNSW_TYPE_VECTOR;
}
/*
* Divide by the norm
*
@@ -158,20 +167,25 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
* if it's different than the original value
*/
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value)
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
/* TODO Remove vector-specific code */
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
if (type == HNSW_TYPE_VECTOR)
{
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
*value = PointerGetDatum(result);
}
else
elog(ERROR, "Unsupported type");
return true;
}