Removed more vector-specific code from IVFFlat

This commit is contained in:
Andrew Kane
2024-04-11 13:59:20 -07:00
parent bd52ed29e0
commit 17c2f9c0b6
3 changed files with 17 additions and 28 deletions

View File

@@ -264,13 +264,12 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
#define VECTOR_ARRAY_SIZE(_length, _size) (sizeof(VectorArrayData) + (_length) * _size)
#define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char*) (_arr)->items + (_offset) * (_arr)->itemsize)
#define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset))
#define VectorArrayGet(_arr, _offset) VECTOR_ARRAY_OFFSET(_arr, _offset)
#define VectorArraySet(_arr, _offset, _val) memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), _val, (_arr)->itemsize)
/* Methods */
VectorArray VectorArrayInit(int maxlen, int dimensions, Size itemsize);
void VectorArrayFree(VectorArray arr);
void PrintVectorArray(char *msg, VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
IvfflatType IvfflatGetType(Relation index);

View File

@@ -43,12 +43,12 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low
for (j = 0; j < numSamples; j++)
{
Vector *vec = VectorArrayGet(samples, j);
Datum vec = PointerGetDatum(VectorArrayGet(samples, j));
double distance;
/* Only need to compute distance for new center */
/* TODO Use triangle inequality to reduce distance calculations */
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, i))));
/* Set lower bound */
lowerBound[j * numCenters + i] = distance;
@@ -123,7 +123,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
qsort(samples->items, samples->length, samples->itemsize, CompareVectors);
for (int i = 0; i < samples->length; i++)
{
Vector *vec = VectorArrayGet(samples, i);
Vector *vec = (Vector *) VectorArrayGet(samples, i);
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0)
{
@@ -136,7 +136,7 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers)
/* Fill remaining with random data */
while (centers->length < centers->maxlen)
{
Vector *vec = VectorArrayGet(centers, centers->length);
Vector *vec = (Vector *) VectorArrayGet(centers, centers->length);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
@@ -248,7 +248,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = VectorArrayGet(newCenters, j);
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
@@ -296,11 +296,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
/* Step 1: For all centers, compute distance */
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = VectorArrayGet(centers, j);
Datum vec = PointerGetDatum(VectorArrayGet(centers, j));
for (int64 k = j + 1; k < numCenters; k++)
{
float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));
halfcdist[j * numCenters + k] = distance;
halfcdist[k * numCenters + j] = distance;
@@ -341,7 +341,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
for (int64 k = 0; k < numCenters; k++)
{
Vector *vec;
Datum vec;
float dxcx;
/* Step 3: For all remaining points x and centers c */
@@ -354,12 +354,12 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
continue;
vec = VectorArrayGet(samples, j);
vec = PointerGetDatum(VectorArrayGet(samples, j));
/* Step 3a */
if (rj)
{
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
/* d(x,c(x)) computed, which is a form of d(x,c) */
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
@@ -373,7 +373,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
/* Step 3b */
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
{
float dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
float dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));
/* d(x,c) calculated */
lowerBound[j * numCenters + k] = dxc;
@@ -394,7 +394,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
/* Step 4: For each center c, let m(c) be mean of all points assigned */
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = VectorArrayGet(newCenters, j);
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
for (int64 k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
@@ -405,8 +405,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
for (int64 j = 0; j < numSamples; j++)
{
int closestCenter = closestCenters[j];
Vector *vec = VectorArrayGet(samples, j);
Vector *newCenter = VectorArrayGet(newCenters, closestCenter);
Vector *vec = (Vector *) VectorArrayGet(samples, j);
Vector *newCenter = (Vector *) VectorArrayGet(newCenters, closestCenter);
/* Increment sum and count of closest center */
for (int64 k = 0; k < dimensions; k++)
@@ -417,7 +417,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
for (int64 j = 0; j < numCenters; j++)
{
Vector *vec = VectorArrayGet(newCenters, j);
Vector *vec = (Vector *) VectorArrayGet(newCenters, j);
if (centerCounts[j] > 0)
{
@@ -492,7 +492,7 @@ CheckCenters(Relation index, VectorArray centers)
/* Ensure no NaN or infinite values */
for (int i = 0; i < centers->length; i++)
{
Vector *vec = VectorArrayGet(centers, i);
Vector *vec = (Vector *) VectorArrayGet(centers, i);
for (int j = 0; j < vec->dim; j++)
{

View File

@@ -33,16 +33,6 @@ VectorArrayFree(VectorArray arr)
pfree(arr);
}
/*
* Print vector array - useful for debugging
*/
void
PrintVectorArray(char *msg, VectorArray arr)
{
for (int i = 0; i < arr->length; i++)
PrintVector(msg, VectorArrayGet(arr, i));
}
/*
* Get the number of lists in the index
*/