diff --git a/src/hnsw.h b/src/hnsw.h index 6f96cc3..901f22b 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -55,6 +55,11 @@ #define HNSW_UPDATE_ENTRY_GREATER 1 #define HNSW_UPDATE_ENTRY_ALWAYS 2 +typedef enum HnswType +{ + HNSW_TYPE_VECTOR +} HnswType; + /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_HNSW_PHASE_LOAD 2 @@ -242,6 +247,7 @@ typedef struct HnswBuildState Relation index; IndexInfo *indexInfo; ForkNumber forkNum; + HnswType type; /* Settings */ int dimensions; @@ -366,7 +372,8 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); -bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value); +HnswType HnswGetType(Relation index); +bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 7a89a61..cd7150d 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -489,7 +489,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value)) + if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type)) return false; } @@ -677,6 +677,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->forkNum = forkNum; + buildstate->type = HnswGetType(index); buildstate->m = HnswGetM(index); buildstate->efConstruction = HnswGetEfConstruction(index); diff --git a/src/hnswinsert.c b/src/hnswinsert.c index cf67518..0e09cfa 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -622,7 +622,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); if (normprocinfo != NULL) { - if (!HnswNormValue(normprocinfo, collation, &value)) + if (!HnswNormValue(normprocinfo, collation, &value, HnswGetType(index))) return; } diff --git a/src/hnswscan.c b/src/hnswscan.c index 365ac71..8dd4efd 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - HnswNormValue(so->normprocinfo, so->collation, &value); + HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation)); } return value; diff --git a/src/hnswutils.c b/src/hnswutils.c index 20a323f..0c828ae 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -149,6 +149,15 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Get vector type + */ +HnswType +HnswGetType(Relation index) +{ + return HNSW_TYPE_VECTOR; +} + /* * Divide by the norm * @@ -158,20 +167,25 @@ HnswOptionalProcInfo(Relation index, uint16 procnum) * if it's different than the original value */ bool -HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value) +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type) { double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { /* TODO Remove vector-specific code */ - Vector *v = DatumGetVector(*value); - Vector *result = InitVector(v->dim); + if (type == HNSW_TYPE_VECTOR) + { + Vector *v = DatumGetVector(*value); + Vector *result = InitVector(v->dim); - for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; - *value = PointerGetDatum(result); + *value = PointerGetDatum(result); + } + else + elog(ERROR, "Unsupported type"); return true; }