From 64223989cd7c3b793dc1d932f309bcfaf1e8cee8 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 16 Oct 2023 15:32:51 -0700 Subject: [PATCH] Use List for samples --- src/ivfbuild.c | 53 +++++++++++++++++++++++++++--------------------- src/ivfflat.h | 9 +++++++-- src/ivfkmeans.c | 54 ++++++++++++++++++++++++++++++++----------------- 3 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 15ba9aa..b8f44c8 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -11,6 +11,7 @@ #include "miscadmin.h" #include "storage/bufmgr.h" #include "tcop/tcopprot.h" +#include "utils/datum.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 @@ -65,11 +66,18 @@ static void AddSample(Datum *values, IvfflatBuildState * buildstate) { - VectorArray samples = buildstate->samples; - int targsamples = samples->maxlen; + MemoryContext oldCtx; + Datum value; + int targsamples = buildstate->targsamples; + + /* Use memory context since detoast can allocate */ + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); /* Detoast once for all calls */ - Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Restore memory context */ + MemoryContextSwitchTo(oldCtx); /* * Normalize with KMEANS_NORM_PROC since spherical distance function @@ -81,18 +89,23 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) return; } - if (samples->length < targsamples) - { - VectorArraySet(samples, samples->length, DatumGetVector(value)); - samples->length++; - } + /* Copy datum */ + value = datumCopy(value, false, -1); + + /* Reset memory context */ + MemoryContextReset(buildstate->tmpCtx); + + if (list_length(buildstate->samples) < targsamples) + buildstate->samples = lappend(buildstate->samples, DatumGetVector(value)); else { if (buildstate->rowstoskip < 0) - buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples); + buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, list_length(buildstate->samples), targsamples); if (buildstate->rowstoskip <= 0) { + ListCell *lc; + #if PG_VERSION_NUM >= 150000 int k = (int) (targsamples * sampler_random_fract(&buildstate->rstate.randstate)); #else @@ -100,7 +113,8 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) #endif Assert(k >= 0 && k < targsamples); - VectorArraySet(samples, k, DatumGetVector(value)); + lc = list_nth_cell(buildstate->samples, k); + lfirst(lc) = DatumGetVector(value); } buildstate->rowstoskip -= 1; @@ -115,21 +129,13 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, bool *isnull, bool tupleIsAlive, void *state) { IvfflatBuildState *buildstate = (IvfflatBuildState *) state; - MemoryContext oldCtx; /* Skip nulls */ if (isnull[0]) return; - /* Use memory context since detoast can allocate */ - oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); - /* Add sample */ - AddSample(values, state); - - /* Reset memory context */ - MemoryContextSwitchTo(oldCtx); - MemoryContextReset(buildstate->tmpCtx); + AddSample(values, buildstate); } /* @@ -138,7 +144,7 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, static void SampleRows(IvfflatBuildState * buildstate) { - int targsamples = buildstate->samples->maxlen; + int targsamples = buildstate->targsamples; BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap); buildstate->rowstoskip = -1; @@ -449,12 +455,13 @@ ComputeCenters(IvfflatBuildState * buildstate) /* Sample rows */ /* TODO Ensure within maintenance_work_mem */ - buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions); + buildstate->samples = NIL; + buildstate->targsamples = numSamples; if (buildstate->heap != NULL) { SampleRows(buildstate); - if (buildstate->samples->length < buildstate->lists) + if (list_length(buildstate->samples) < buildstate->lists) { ereport(NOTICE, (errmsg("ivfflat index created with little data"), @@ -467,7 +474,7 @@ ComputeCenters(IvfflatBuildState * buildstate) IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); /* Free samples before we allocate more memory */ - VectorArrayFree(buildstate->samples); + list_free_deep(buildstate->samples); } /* diff --git a/src/ivfflat.h b/src/ivfflat.h index 1eb35b0..a281700 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -80,6 +80,10 @@ #define RandomInt() random() #endif +#if PG_VERSION_NUM < 130000 +#define list_sort(list, cmp) list_qsort(list, cmp) +#endif + /* Variables */ extern int ivfflat_probes; @@ -178,7 +182,8 @@ typedef struct IvfflatBuildState Oid collation; /* Variables */ - VectorArray samples; + List *samples; + int targsamples; VectorArray centers; ListInfo *listInfo; Vector *normvec; @@ -274,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque; VectorArray VectorArrayInit(int maxlen, int dimensions); void VectorArrayFree(VectorArray arr); void PrintVectorArray(char *msg, VectorArray arr); -void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); +void IvfflatKmeans(Relation index, List *samples, VectorArray centers); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); int IvfflatGetLists(Relation index); diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index a87edcb..1bdf393 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -12,20 +12,20 @@ * https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf */ static void -InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound) +InitCenters(Relation index, List *samples, VectorArray centers, float *lowerBound) { FmgrInfo *procinfo; Oid collation; int64 j; - float *weight = palloc(samples->length * sizeof(float)); + float *weight = palloc(list_length(samples) * sizeof(float)); int numCenters = centers->maxlen; - int numSamples = samples->length; + int numSamples = list_length(samples); procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); collation = index->rd_indcollation[0]; /* Choose an initial center uniformly at random */ - VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length)); + VectorArraySet(centers, 0, list_nth(samples, RandomInt() % list_length(samples))); centers->length++; for (j = 0; j < numSamples; j++) @@ -42,7 +42,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low for (j = 0; j < numSamples; j++) { - Vector *vec = VectorArrayGet(samples, j); + Vector *vec = list_nth(samples, j); double distance; /* Only need to compute distance for new center */ @@ -74,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low break; } - VectorArraySet(centers, i + 1, VectorArrayGet(samples, j)); + VectorArraySet(centers, i + 1, list_nth(samples, j)); centers->length++; } @@ -106,25 +106,41 @@ CompareVectors(const void *a, const void *b) return vector_cmp_internal((Vector *) a, (Vector *) b); } +/* + * Compare list vectors + */ +static int +#if PG_VERSION_NUM >= 130000 +CompareListVectors(const ListCell *a, const ListCell *b) +#else +CompareListVectors(const void *a, const void *b) +#endif +{ + Vector *va = lfirst((ListCell *) a); + Vector *vb = lfirst((ListCell *) b); + + return CompareVectors(va, vb); +} + /* * Quick approach if we have little data */ static void -QuickCenters(Relation index, VectorArray samples, VectorArray centers) +QuickCenters(Relation index, List *samples, VectorArray centers) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); /* Copy existing vectors while avoiding duplicates */ - if (samples->length > 0) + if (list_length(samples) > 0) { - qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors); - for (int i = 0; i < samples->length; i++) + list_sort(samples, CompareListVectors); + for (int i = 0; i < list_length(samples); i++) { - Vector *vec = VectorArrayGet(samples, i); + Vector *vec = list_nth(samples, i); - if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0) + if (i == 0 || CompareVectors(vec, list_nth(samples, i - 1)) != 0) { VectorArraySet(centers, centers->length, vec); centers->length++; @@ -160,7 +176,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void -ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) +ElkanKmeans(Relation index, List *samples, VectorArray centers) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; @@ -171,7 +187,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) int64 k; int dimensions = centers->dim; int numCenters = centers->maxlen; - int numSamples = samples->length; + int numSamples = list_length(samples); VectorArray newCenters; int *centerCounts; int *closestCenters; @@ -182,7 +198,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) float *newcdist; /* Calculate allocation sizes */ - Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); + Size samplesSize = 0; Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions); Size centerCountsSize = sizeof(int) * numCenters; @@ -326,7 +342,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k]) continue; - vec = VectorArrayGet(samples, j); + vec = list_nth(samples, j); /* Step 3a */ if (rj) @@ -377,7 +393,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) { int closestCenter; - vec = VectorArrayGet(samples, j); + vec = list_nth(samples, j); closestCenter = closestCenters[j]; /* Increment sum and count of closest center */ @@ -514,9 +530,9 @@ CheckCenters(Relation index, VectorArray centers) * We use spherical k-means for inner product and cosine */ void -IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) +IvfflatKmeans(Relation index, List *samples, VectorArray centers) { - if (samples->length <= centers->maxlen) + if (list_length(samples) <= centers->maxlen) QuickCenters(index, samples, centers); else ElkanKmeans(index, samples, centers);