diff --git a/src/ivfflat.h b/src/ivfflat.h index 5b619db..5393eb0 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -256,6 +256,7 @@ typedef struct IvfflatScanOpaqueData FmgrInfo *procinfo; FmgrInfo *normprocinfo; Oid collation; + Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2); /* Lists */ pairingheap *listQueue; diff --git a/src/ivfscan.c b/src/ivfscan.c index 42a05d4..018a636 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -3,10 +3,8 @@ #include #include "access/relscan.h" -#include "bitvec.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" -#include "halfvec.h" #include "lib/pairingheap.h" #include "ivfflat.h" #include "miscadmin.h" @@ -58,7 +56,7 @@ GetScanLists(IndexScanDesc scan, Datum value) double distance; /* Use procinfo from the index instead of scan key for performance */ - distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); + distance = DatumGetFloat8(so->distfunc(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); if (listCount < so->probes) { @@ -151,7 +149,7 @@ GetScanItems(IndexScanDesc scan, Datum value) * performance */ ExecClearTuple(slot); - slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value); + slot->tts_values[0] = so->distfunc(so->procinfo, so->collation, datum, value); slot->tts_isnull[0] = false; slot->tts_values[1] = PointerGetDatum(&itup->t_tid); slot->tts_isnull[1] = false; @@ -179,6 +177,15 @@ GetScanItems(IndexScanDesc scan, Datum value) tuplesort_performsort(so->sortstate); } +/* + * Zero distance + */ +static Datum +ZeroDistance(FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2) +{ + return Float8GetDatum(0.0); +} + /* * Get scan value */ @@ -190,20 +197,13 @@ GetScanValue(IndexScanDesc scan) if (scan->orderByData->sk_flags & SK_ISNULL) { - IvfflatType type = IvfflatGetType(scan->indexRelation); - - if (type == IVFFLAT_TYPE_VECTOR) - value = PointerGetDatum(InitVector(so->dimensions)); - else if (type == IVFFLAT_TYPE_HALFVEC) - value = PointerGetDatum(InitHalfVector(so->dimensions)); - else if (type == IVFFLAT_TYPE_BIT) - value = PointerGetDatum(InitBitVector(so->dimensions)); - else - elog(ERROR, "Unsupported type"); + value = PointerGetDatum(NULL); + so->distfunc = ZeroDistance; } else { value = scan->orderByData->sk_argument; + so->distfunc = FunctionCall2Coll; /* Value should not be compressed or toasted */ Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));