Mark type-specific code

This commit is contained in:
Andrew Kane
2024-03-29 14:01:48 -07:00
parent 7d63bb4b98
commit 2c48e3edc2
5 changed files with 32 additions and 10 deletions

View File

@@ -149,6 +149,15 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
return index_getprocinfo(index, 1, procnum);
}
/*
* Get vector type
*/
HnswType
HnswGetType(Relation index)
{
return HNSW_TYPE_VECTOR;
}
/*
* Divide by the norm
*
@@ -158,20 +167,25 @@ HnswOptionalProcInfo(Relation index, uint16 procnum)
* if it's different than the original value
*/
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value)
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
/* TODO Remove vector-specific code */
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
if (type == HNSW_TYPE_VECTOR)
{
Vector *v = DatumGetVector(*value);
Vector *result = InitVector(v->dim);
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
for (int i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
*value = PointerGetDatum(result);
}
else
elog(ERROR, "Unsupported type");
return true;
}