mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Use List for samples
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "miscadmin.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "tcop/tcopprot.h"
|
||||
#include "utils/datum.h"
|
||||
#include "utils/memutils.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 140000
|
||||
@@ -65,11 +66,18 @@
|
||||
static void
|
||||
AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
{
|
||||
VectorArray samples = buildstate->samples;
|
||||
int targsamples = samples->maxlen;
|
||||
MemoryContext oldCtx;
|
||||
Datum value;
|
||||
int targsamples = buildstate->targsamples;
|
||||
|
||||
/* Use memory context since detoast can allocate */
|
||||
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
|
||||
|
||||
/* Detoast once for all calls */
|
||||
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
|
||||
/* Restore memory context */
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
|
||||
/*
|
||||
* Normalize with KMEANS_NORM_PROC since spherical distance function
|
||||
@@ -81,18 +89,23 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
return;
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
{
|
||||
VectorArraySet(samples, samples->length, DatumGetVector(value));
|
||||
samples->length++;
|
||||
}
|
||||
/* Copy datum */
|
||||
value = datumCopy(value, false, -1);
|
||||
|
||||
/* Reset memory context */
|
||||
MemoryContextReset(buildstate->tmpCtx);
|
||||
|
||||
if (list_length(buildstate->samples) < targsamples)
|
||||
buildstate->samples = lappend(buildstate->samples, DatumGetVector(value));
|
||||
else
|
||||
{
|
||||
if (buildstate->rowstoskip < 0)
|
||||
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
|
||||
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, list_length(buildstate->samples), targsamples);
|
||||
|
||||
if (buildstate->rowstoskip <= 0)
|
||||
{
|
||||
ListCell *lc;
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
int k = (int) (targsamples * sampler_random_fract(&buildstate->rstate.randstate));
|
||||
#else
|
||||
@@ -100,7 +113,8 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
#endif
|
||||
|
||||
Assert(k >= 0 && k < targsamples);
|
||||
VectorArraySet(samples, k, DatumGetVector(value));
|
||||
lc = list_nth_cell(buildstate->samples, k);
|
||||
lfirst(lc) = DatumGetVector(value);
|
||||
}
|
||||
|
||||
buildstate->rowstoskip -= 1;
|
||||
@@ -115,21 +129,13 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
bool *isnull, bool tupleIsAlive, void *state)
|
||||
{
|
||||
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
|
||||
MemoryContext oldCtx;
|
||||
|
||||
/* Skip nulls */
|
||||
if (isnull[0])
|
||||
return;
|
||||
|
||||
/* Use memory context since detoast can allocate */
|
||||
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
|
||||
|
||||
/* Add sample */
|
||||
AddSample(values, state);
|
||||
|
||||
/* Reset memory context */
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
MemoryContextReset(buildstate->tmpCtx);
|
||||
AddSample(values, buildstate);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -138,7 +144,7 @@ SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
static void
|
||||
SampleRows(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int targsamples = buildstate->samples->maxlen;
|
||||
int targsamples = buildstate->targsamples;
|
||||
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
|
||||
|
||||
buildstate->rowstoskip = -1;
|
||||
@@ -449,12 +455,13 @@ ComputeCenters(IvfflatBuildState * buildstate)
|
||||
|
||||
/* Sample rows */
|
||||
/* TODO Ensure within maintenance_work_mem */
|
||||
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
|
||||
buildstate->samples = NIL;
|
||||
buildstate->targsamples = numSamples;
|
||||
if (buildstate->heap != NULL)
|
||||
{
|
||||
SampleRows(buildstate);
|
||||
|
||||
if (buildstate->samples->length < buildstate->lists)
|
||||
if (list_length(buildstate->samples) < buildstate->lists)
|
||||
{
|
||||
ereport(NOTICE,
|
||||
(errmsg("ivfflat index created with little data"),
|
||||
@@ -467,7 +474,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
|
||||
IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers));
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
VectorArrayFree(buildstate->samples);
|
||||
list_free_deep(buildstate->samples);
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -80,6 +80,10 @@
|
||||
#define RandomInt() random()
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM < 130000
|
||||
#define list_sort(list, cmp) list_qsort(list, cmp)
|
||||
#endif
|
||||
|
||||
/* Variables */
|
||||
extern int ivfflat_probes;
|
||||
|
||||
@@ -178,7 +182,8 @@ typedef struct IvfflatBuildState
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
VectorArray samples;
|
||||
List *samples;
|
||||
int targsamples;
|
||||
VectorArray centers;
|
||||
ListInfo *listInfo;
|
||||
Vector *normvec;
|
||||
@@ -274,7 +279,7 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
|
||||
VectorArray VectorArrayInit(int maxlen, int dimensions);
|
||||
void VectorArrayFree(VectorArray arr);
|
||||
void PrintVectorArray(char *msg, VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
|
||||
void IvfflatKmeans(Relation index, List *samples, VectorArray centers);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
int IvfflatGetLists(Relation index);
|
||||
|
||||
@@ -12,20 +12,20 @@
|
||||
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
|
||||
*/
|
||||
static void
|
||||
InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound)
|
||||
InitCenters(Relation index, List *samples, VectorArray centers, float *lowerBound)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
Oid collation;
|
||||
int64 j;
|
||||
float *weight = palloc(samples->length * sizeof(float));
|
||||
float *weight = palloc(list_length(samples) * sizeof(float));
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
int numSamples = list_length(samples);
|
||||
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Choose an initial center uniformly at random */
|
||||
VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length));
|
||||
VectorArraySet(centers, 0, list_nth(samples, RandomInt() % list_length(samples)));
|
||||
centers->length++;
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
@@ -42,7 +42,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
Vector *vec = VectorArrayGet(samples, j);
|
||||
Vector *vec = list_nth(samples, j);
|
||||
double distance;
|
||||
|
||||
/* Only need to compute distance for new center */
|
||||
@@ -74,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
|
||||
break;
|
||||
}
|
||||
|
||||
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
|
||||
VectorArraySet(centers, i + 1, list_nth(samples, j));
|
||||
centers->length++;
|
||||
}
|
||||
|
||||
@@ -106,25 +106,41 @@ CompareVectors(const void *a, const void *b)
|
||||
return vector_cmp_internal((Vector *) a, (Vector *) b);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare list vectors
|
||||
*/
|
||||
static int
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
CompareListVectors(const ListCell *a, const ListCell *b)
|
||||
#else
|
||||
CompareListVectors(const void *a, const void *b)
|
||||
#endif
|
||||
{
|
||||
Vector *va = lfirst((ListCell *) a);
|
||||
Vector *vb = lfirst((ListCell *) b);
|
||||
|
||||
return CompareVectors(va, vb);
|
||||
}
|
||||
|
||||
/*
|
||||
* Quick approach if we have little data
|
||||
*/
|
||||
static void
|
||||
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
|
||||
QuickCenters(Relation index, List *samples, VectorArray centers)
|
||||
{
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
|
||||
/* Copy existing vectors while avoiding duplicates */
|
||||
if (samples->length > 0)
|
||||
if (list_length(samples) > 0)
|
||||
{
|
||||
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);
|
||||
for (int i = 0; i < samples->length; i++)
|
||||
list_sort(samples, CompareListVectors);
|
||||
for (int i = 0; i < list_length(samples); i++)
|
||||
{
|
||||
Vector *vec = VectorArrayGet(samples, i);
|
||||
Vector *vec = list_nth(samples, i);
|
||||
|
||||
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0)
|
||||
if (i == 0 || CompareVectors(vec, list_nth(samples, i - 1)) != 0)
|
||||
{
|
||||
VectorArraySet(centers, centers->length, vec);
|
||||
centers->length++;
|
||||
@@ -160,7 +176,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
|
||||
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
||||
*/
|
||||
static void
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
ElkanKmeans(Relation index, List *samples, VectorArray centers)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
@@ -171,7 +187,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
int64 k;
|
||||
int dimensions = centers->dim;
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
int numSamples = list_length(samples);
|
||||
VectorArray newCenters;
|
||||
int *centerCounts;
|
||||
int *closestCenters;
|
||||
@@ -182,7 +198,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
float *newcdist;
|
||||
|
||||
/* Calculate allocation sizes */
|
||||
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim);
|
||||
Size samplesSize = 0;
|
||||
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim);
|
||||
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions);
|
||||
Size centerCountsSize = sizeof(int) * numCenters;
|
||||
@@ -326,7 +342,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
|
||||
continue;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
vec = list_nth(samples, j);
|
||||
|
||||
/* Step 3a */
|
||||
if (rj)
|
||||
@@ -377,7 +393,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
int closestCenter;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
vec = list_nth(samples, j);
|
||||
closestCenter = closestCenters[j];
|
||||
|
||||
/* Increment sum and count of closest center */
|
||||
@@ -514,9 +530,9 @@ CheckCenters(Relation index, VectorArray centers)
|
||||
* We use spherical k-means for inner product and cosine
|
||||
*/
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
IvfflatKmeans(Relation index, List *samples, VectorArray centers)
|
||||
{
|
||||
if (samples->length <= centers->maxlen)
|
||||
if (list_length(samples) <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers);
|
||||
|
||||
Reference in New Issue
Block a user