Use List for samples

This commit is contained in:
Andrew Kane
2023-10-16 15:32:51 -07:00
parent e630efd195
commit 64223989cd
3 changed files with 72 additions and 44 deletions

View File

@@ -11,6 +11,7 @@
#include "miscadmin.h"
#include "storage/bufmgr.h"
#include "tcop/tcopprot.h"
#include "utils/datum.h"
#include "utils/memutils.h"
#if PG_VERSION_NUM >= 140000
@@ -65,11 +66,18 @@
static void
AddSample(Datum *values, IvfflatBuildState * buildstate)
{
VectorArray samples = buildstate->samples;
int targsamples = samples->maxlen;
MemoryContext oldCtx;
Datum value;
int targsamples = buildstate->targsamples;
/* Use memory context since detoast can allocate */
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
/* Detoast once for all calls */
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Restore memory context */
MemoryContextSwitchTo(oldCtx);
/*
* Normalize with KMEANS_NORM_PROC since spherical distance function
@@ -81,18 +89,23 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
return;
}
if (samples->length < targsamples)
{
VectorArraySet(samples, samples->length, DatumGetVector(value));
samples->length++;
}
/* Copy datum */
value = datumCopy(value, false, -1);
/* Reset memory context */
MemoryContextReset(buildstate->tmpCtx);
if (list_length(buildstate->samples) < targsamples)
buildstate->samples = lappend(buildstate->samples, DatumGetVector(value));
else
{
if (buildstate->rowstoskip < 0)
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, list_length(buildstate->samples), targsamples);
if (buildstate->rowstoskip <= 0)
{
ListCell *lc;
#if PG_VERSION_NUM >= 150000
int k = (int) (targsamples * sampler_random_fract(&buildstate->rstate.randstate));
#else
@@ -100,7 +113,8 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
#endif
Assert(k >= 0 && k < targsamples);
VectorArraySet(samples, k, DatumGetVector(value));
lc = list_nth_cell(buildstate->samples, k);
lfirst(lc) = DatumGetVector(value);
}
buildstate->rowstoskip -= 1;
@@ -115,21 +129,13 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
MemoryContext oldCtx;
/* Skip nulls */
if (isnull[0])
return;
/* Use memory context since detoast can allocate */
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
/* Add sample */
AddSample(values, state);
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
MemoryContextReset(buildstate->tmpCtx);
AddSample(values, buildstate);
}
/*
@@ -138,7 +144,7 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
static void
SampleRows(IvfflatBuildState * buildstate)
{
int targsamples = buildstate->samples->maxlen;
int targsamples = buildstate->targsamples;
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
buildstate->rowstoskip = -1;
@@ -449,12 +455,13 @@ ComputeCenters(IvfflatBuildState * buildstate)
/* Sample rows */
/* TODO Ensure within maintenance_work_mem */
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
buildstate->samples = NIL;
buildstate->targsamples = numSamples;
if (buildstate->heap != NULL)
{
SampleRows(buildstate);
if (buildstate->samples->length < buildstate->lists)
if (list_length(buildstate->samples) < buildstate->lists)
{
ereport(NOTICE,
(errmsg("ivfflat index created with little data"),
@@ -467,7 +474,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
/* Free samples before we allocate more memory */
VectorArrayFree(buildstate->samples);
list_free_deep(buildstate->samples);
}
/*

View File

@@ -80,6 +80,10 @@
#define RandomInt() random()
#endif
#if PG_VERSION_NUM < 130000
#define list_sort(list, cmp) list_qsort(list, cmp)
#endif
/* Variables */
extern int ivfflat_probes;
@@ -178,7 +182,8 @@ typedef struct IvfflatBuildState
Oid collation;
/* Variables */
VectorArray samples;
List *samples;
int targsamples;
VectorArray centers;
ListInfo *listInfo;
Vector *normvec;
@@ -274,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
VectorArray VectorArrayInit(int maxlen, int dimensions);
void VectorArrayFree(VectorArray arr);
void PrintVectorArray(char *msg, VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
void IvfflatKmeans(Relation index, List *samples, VectorArray centers);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
int IvfflatGetLists(Relation index);

View File

@@ -12,20 +12,20 @@
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
*/
static void
InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound)
InitCenters(Relation index, List *samples, VectorArray centers, float *lowerBound)
{
FmgrInfo *procinfo;
Oid collation;
int64 j;
float *weight = palloc(samples->length * sizeof(float));
float *weight = palloc(list_length(samples) * sizeof(float));
int numCenters = centers->maxlen;
int numSamples = samples->length;
int numSamples = list_length(samples);
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
collation = index->rd_indcollation[0];
/* Choose an initial center uniformly at random */
VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length));
VectorArraySet(centers, 0, list_nth(samples, RandomInt() % list_length(samples)));
centers->length++;
for (j = 0; j < numSamples; j++)
@@ -42,7 +42,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
for (j = 0; j < numSamples; j++)
{
Vector *vec = VectorArrayGet(samples, j);
Vector *vec = list_nth(samples, j);
double distance;
/* Only need to compute distance for new center */
@@ -74,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
break;
}
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
VectorArraySet(centers, i + 1, list_nth(samples, j));
centers->length++;
}
@@ -106,25 +106,41 @@ CompareVectors(const void *a, const void *b)
return vector_cmp_internal((Vector *) a, (Vector *) b);
}
/*
* Compare list vectors
*/
static int
#if PG_VERSION_NUM >= 130000
CompareListVectors(const ListCell *a, const ListCell *b)
#else
CompareListVectors(const void *a, const void *b)
#endif
{
Vector *va = lfirst((ListCell *) a);
Vector *vb = lfirst((ListCell *) b);
return CompareVectors(va, vb);
}
/*
* Quick approach if we have little data
*/
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
QuickCenters(Relation index, List *samples, VectorArray centers)
{
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
/* Copy existing vectors while avoiding duplicates */
if (samples->length > 0)
if (list_length(samples) > 0)
{
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);
for (int i = 0; i < samples->length; i++)
list_sort(samples, CompareListVectors);
for (int i = 0; i < list_length(samples); i++)
{
Vector *vec = VectorArrayGet(samples, i);
Vector *vec = list_nth(samples, i);
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0)
if (i == 0 || CompareVectors(vec, list_nth(samples, i - 1)) != 0)
{
VectorArraySet(centers, centers->length, vec);
centers->length++;
@@ -160,7 +176,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
ElkanKmeans(Relation index, List *samples, VectorArray centers)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
@@ -171,7 +187,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
int64 k;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
int numSamples = samples->length;
int numSamples = list_length(samples);
VectorArray newCenters;
int *centerCounts;
int *closestCenters;
@@ -182,7 +198,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
float *newcdist;
/* Calculate allocation sizes */
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim);
Size samplesSize = 0;
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions);
Size centerCountsSize = sizeof(int) * numCenters;
@@ -326,7 +342,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
continue;
vec = VectorArrayGet(samples, j);
vec = list_nth(samples, j);
/* Step 3a */
if (rj)
@@ -377,7 +393,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
{
int closestCenter;
vec = VectorArrayGet(samples, j);
vec = list_nth(samples, j);
closestCenter = closestCenters[j];
/* Increment sum and count of closest center */
@@ -514,9 +530,9 @@ CheckCenters(Relation index, VectorArray centers)
* We use spherical k-means for inner product and cosine
*/
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
IvfflatKmeans(Relation index, List *samples, VectorArray centers)
{
if (samples->length <= centers->maxlen)
if (list_length(samples) <= centers->maxlen)
QuickCenters(index, samples, centers);
else
ElkanKmeans(index, samples, centers);