diff --git a/src/ivfbuild.c b/src/ivfbuild.c index d6cf049..fa7a17a 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -57,7 +57,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) */ if (buildstate->kmeansnormprocinfo != NULL) { - if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value)) + if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->type)) return; } @@ -153,7 +153,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value)) + if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type)) return; } @@ -312,25 +312,39 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) } } +/* + * Get max dimensions + */ +static int +GetMaxDimensions(IvfflatType type) +{ + return IVFFLAT_MAX_DIM; +} + /* * Initialize the build state */ static void InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) { + int maxDimensions; + buildstate->heap = heap; buildstate->index = index; buildstate->indexInfo = indexInfo; + buildstate->type = IvfflatGetType(index); buildstate->lists = IvfflatGetLists(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + maxDimensions = GetMaxDimensions(buildstate->type); + /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); - if (buildstate->dimensions > IVFFLAT_MAX_DIM) - elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", IVFFLAT_MAX_DIM); + if (buildstate->dimensions > maxDimensions) + elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", maxDimensions); buildstate->reltuples = 0; buildstate->indtuples = 0; diff --git a/src/ivfflat.h b/src/ivfflat.h index 422967f..149dc20 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -43,6 +43,11 @@ #define IVFFLAT_MAX_LISTS 32768 #define IVFFLAT_DEFAULT_PROBES 1 +typedef enum IvfflatType +{ + IVFFLAT_TYPE_VECTOR +} IvfflatType; + /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_IVFFLAT_PHASE_KMEANS 2 @@ -153,6 +158,7 @@ typedef struct IvfflatBuildState Relation heap; Relation index; IndexInfo *indexInfo; + IvfflatType type; /* Settings */ int dimensions; @@ -266,7 +272,8 @@ void VectorArrayFree(VectorArray arr); void PrintVectorArray(char *msg, VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); -bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value); +IvfflatType IvfflatGetType(Relation index); +bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type); int IvfflatGetLists(Relation index); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); diff --git a/src/ivfinsert.c b/src/ivfinsert.c index 2d8d4c3..6d24ecf 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -85,7 +85,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { - if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value)) + if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, IvfflatGetType(index))) return; } diff --git a/src/ivfscan.c b/src/ivfscan.c index 66b3ae6..ae05f2d 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -268,6 +268,7 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) if (so->first) { Datum value; + IvfflatType type = IvfflatGetType(scan->indexRelation); /* Count index scan for stats */ pgstat_count_index_scan(scan->indexRelation); @@ -282,7 +283,12 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) elog(ERROR, "non-MVCC snapshots are not supported with ivfflat"); if (scan->orderByData->sk_flags & SK_ISNULL) - value = PointerGetDatum(InitVector(so->dimensions)); + { + if (type == IVFFLAT_TYPE_VECTOR) + value = PointerGetDatum(InitVector(so->dimensions)); + else + elog(ERROR, "Unsupported type"); + } else { value = scan->orderByData->sk_argument; @@ -293,7 +299,7 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - IvfflatNormValue(so->normprocinfo, so->collation, &value); + IvfflatNormValue(so->normprocinfo, so->collation, &value, type); } IvfflatBench("GetScanLists", GetScanLists(scan, value)); diff --git a/src/ivfutils.c b/src/ivfutils.c index 587a97e..02e3cbb 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -66,6 +66,15 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum) return index_getprocinfo(index, 1, procnum); } +/* + * Get type + */ +IvfflatType +IvfflatGetType(Relation index) +{ + return IVFFLAT_TYPE_VECTOR; +} + /* * Divide by the norm * @@ -75,19 +84,24 @@ IvfflatOptionalProcInfo(Relation index, uint16 procnum) * if it's different than the original value */ bool -IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value) +IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type) { double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { - Vector *v = DatumGetVector(*value); - Vector *result = InitVector(v->dim); + if (type == IVFFLAT_TYPE_VECTOR) + { + Vector *v = DatumGetVector(*value); + Vector *result = InitVector(v->dim); - for (int i = 0; i < v->dim; i++) - result->x[i] = v->x[i] / norm; + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; - *value = PointerGetDatum(result); + *value = PointerGetDatum(result); + } + else + elog(ERROR, "Unsupported type"); return true; }