Added src directory

This commit is contained in:
Andrew Kane
2021-04-20 14:43:04 -07:00
parent 819ae41e86
commit a3d946f3bf
11 changed files with 1 additions and 1 deletions

503
src/ivfbuild.c Normal file
View File

@@ -0,0 +1,503 @@
#include "postgres.h"
#include <float.h>
#include "catalog/index.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "storage/bufmgr.h"
#if PG_VERSION_NUM >= 120000
#include "access/tableam.h"
#endif
#if PG_VERSION_NUM >= 110000
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#else
#include "catalog/pg_operator.h"
#include "catalog/pg_type.h"
#endif
#if PG_VERSION_NUM >= 130000
#define CALLBACK_ITEM_POINTER ItemPointer tid
#else
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
/*
* Callback for sampling
*/
static void
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
VectorArray samples = buildstate->samples;
int targsamples = samples->maxlen;
Datum value = values[0];
/* Skip nulls */
if (isnull[0])
return;
/* Normalize the value */
if (buildstate->normprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
if (samples->length < targsamples)
{
VectorArraySet(samples, samples->length, DatumGetVector(value));
samples->length++;
}
else
{
if (buildstate->rowstoskip < 0)
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
if (buildstate->rowstoskip <= 0)
{
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
Assert(k >= 0 && k < targsamples);
VectorArraySet(samples, k, DatumGetVector(value));
}
buildstate->rowstoskip -= 1;
}
}
/*
* Sample rows with same logic as ANALYZE
*/
static void
SampleRows(IvfflatBuildState * buildstate)
{
int targsamples = buildstate->samples->maxlen;
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
buildstate->rowstoskip = -1;
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
reservoir_init_selection_state(&buildstate->rstate, targsamples);
while (BlockSampler_HasMore(&buildstate->bs))
{
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
#if PG_VERSION_NUM >= 120000
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
#else
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, targblock, 1, SampleCallback, (void *) buildstate);
#endif
}
}
/*
* Callback for table_index_build_scan
*/
static void
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
bool *isnull, bool tupleIsAlive, void *state)
{
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
double distance;
double minDistance = DBL_MAX;
int closestCenter = -1;
VectorArray centers = buildstate->centers;
TupleTableSlot *slot = buildstate->slot;
Datum value = values[0];
int i;
#if PG_VERSION_NUM < 130000
ItemPointer tid = &hup->t_self;
#endif
if (isnull[0])
return;
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
/* Find the list that minimizes the distance */
for (i = 0; i < centers->length; i++)
{
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i))));
if (distance < minDistance)
{
minDistance = distance;
closestCenter = i;
}
}
/* Create a virtual tuple */
ExecClearTuple(slot);
slot->tts_values[0] = Int32GetDatum(closestCenter);
slot->tts_isnull[0] = false;
slot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid));
slot->tts_isnull[1] = false;
slot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid));
slot->tts_isnull[2] = false;
slot->tts_values[3] = value;
slot->tts_isnull[3] = false;
ExecStoreVirtualTuple(slot);
/*
* Add tuple to sort
*
* tuplesort_puttupleslot comment: Input data is always copied; the caller
* need not save it.
*/
tuplesort_puttupleslot(buildstate->sortstate, slot);
}
/*
* Get index tuple from sort state
*/
static inline void
GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list)
{
Datum value;
bool isnull;
int tupblk;
int tupoff;
#if PG_VERSION_NUM >= 100000
if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL))
#else
if (tuplesort_gettupleslot(sortstate, true, slot, NULL))
#endif
{
*list = DatumGetInt32(slot_getattr(slot, 1, &isnull));
tupblk = DatumGetInt32(slot_getattr(slot, 2, &isnull));
tupoff = DatumGetInt32(slot_getattr(slot, 3, &isnull));
value = slot_getattr(slot, 4, &isnull);
/* Form the index tuple */
*itup = index_form_tuple(tupdesc, &value, &isnull);
ItemPointerSet(&(*itup)->t_tid, tupblk, tupoff);
}
else
*list = -1;
}
/*
* Create initial entry pages
*/
static void
InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
{
Buffer buf;
Page page;
GenericXLogState *state;
int list;
IndexTuple itup = NULL; /* silence compiler warning */
BlockNumber startPage = InvalidBlockNumber;
BlockNumber insertPage = InvalidBlockNumber;
Size itemsz;
int i;
#if PG_VERSION_NUM >= 120000
TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple);
#else
TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
#endif
TupleDesc tupdesc = RelationGetDescr(index);
GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);
for (i = 0; i < buildstate->centers->length; i++)
{
buf = IvfflatNewBuffer(index, forkNum);
IvfflatInitPage(index, &buf, &page, &state);
startPage = BufferGetBlockNumber(buf);
/* Get all tuples for list */
while (list == i)
{
/* Check for free space */
itemsz = MAXALIGN(IndexTupleSize(itup));
if (PageGetFreeSpace(page) < itemsz)
IvfflatAppendPage(index, &buf, &page, &state, forkNum);
/* Add the item */
if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
buildstate->indtuples += 1;
GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);
}
insertPage = BufferGetBlockNumber(buf);
IvfflatCommitBuffer(buf, state);
/* Set the start and insert pages */
IvfflatUpdateList(index, state, buildstate->listInfo[i], insertPage, startPage, forkNum);
}
}
/*
* Initialize the build state
*/
static void
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
{
buildstate->heap = heap;
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->lists = IvfflatGetLists(index);
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
/* Require column to have dimensions to be indexed */
if (buildstate->dimensions < 0)
elog(ERROR, "column does not have dimensions");
buildstate->reltuples = 0;
buildstate->indtuples = 0;
/* Get support functions */
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
buildstate->collation = index->rd_indcollation[0];
/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
buildstate->tupdesc = CreateTemplateTupleDesc(4);
#else
buildstate->tupdesc = CreateTemplateTupleDesc(4, false);
#endif
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
#if PG_VERSION_NUM >= 110000
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0);
#else
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0);
#endif
#if PG_VERSION_NUM >= 120000
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
#else
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
#endif
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions);
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
/* Reuse for each tuple */
buildstate->normvec = InitVector(buildstate->dimensions);
}
/*
* Free resources
*/
static void
FreeBuildState(IvfflatBuildState * buildstate)
{
pfree(buildstate->centers);
pfree(buildstate->listInfo);
pfree(buildstate->normvec);
}
/*
* Compute centers
*/
static void
ComputeCenters(IvfflatBuildState * buildstate)
{
int numSamples;
/* Target 50 samples per list, with at least 10000 samples */
/* The number of samples has a large effect on index build time */
numSamples = buildstate->lists * 50;
if (numSamples < 10000)
numSamples = 10000;
/* Sample samples */
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
if (buildstate->heap != NULL)
SampleRows(buildstate);
/* Calculate centers */
IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers);
/* Free samples before we allocate more memory */
pfree(buildstate->samples);
}
/*
* Create the metapage
*/
static void
CreateMetaPage(Relation index, int dimensions, int lists, ForkNumber forkNum)
{
Buffer buf;
Page page;
GenericXLogState *state;
IvfflatMetaPage metap;
buf = IvfflatNewBuffer(index, forkNum);
IvfflatInitPage(index, &buf, &page, &state);
/* Set metapage data */
metap = IvfflatPageGetMeta(page);
metap->magicNumber = IVFFLAT_MAGIC_NUMBER;
metap->version = IVFFLAT_VERSION;
metap->dimensions = dimensions;
metap->lists = lists;
((PageHeader) page)->pd_lower =
((char *) metap + sizeof(IvfflatMetaPageData)) - (char *) page;
IvfflatCommitBuffer(buf, state);
}
/*
* Create list pages
*/
static void
CreateListPages(Relation index, VectorArray centers, int dimensions,
int lists, ForkNumber forkNum, ListInfo * *listInfo)
{
int i;
Buffer buf;
Page page;
GenericXLogState *state;
OffsetNumber offno;
Size itemsz;
IvfflatList list;
itemsz = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions));
list = palloc(itemsz);
buf = IvfflatNewBuffer(index, forkNum);
IvfflatInitPage(index, &buf, &page, &state);
for (i = 0; i < lists; i++)
{
/* Load list */
list->startPage = InvalidBlockNumber;
list->insertPage = InvalidBlockNumber;
memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions));
/* Ensure free space */
if (PageGetFreeSpace(page) < itemsz)
IvfflatAppendPage(index, &buf, &page, &state, forkNum);
/* Add the item */
offno = PageAddItem(page, (Item) list, itemsz, InvalidOffsetNumber, false, false);
if (offno == InvalidOffsetNumber)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
/* Save location info */
(*listInfo)[i].blkno = BufferGetBlockNumber(buf);
(*listInfo)[i].offno = offno;
}
IvfflatCommitBuffer(buf, state);
pfree(list);
}
/*
* Create entry pages
*/
static void
CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum)
{
AttrNumber attNums[] = {1};
Oid sortOperators[] = {Float8LessOperator};
Oid sortCollations[] = {InvalidOid};
bool nullsFirstFlags[] = {false};
#if PG_VERSION_NUM >= 110000
buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, NULL, false);
#else
buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, false);
#endif
/* Add tuples to sort */
if (buildstate->heap != NULL)
{
#if PG_VERSION_NUM >= 120000
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, BuildCallback, (void *) buildstate, NULL);
#elif PG_VERSION_NUM >= 110000
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate, NULL);
#else
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate);
#endif
}
/* Sort and insert */
tuplesort_performsort(buildstate->sortstate);
InsertTuples(buildstate->index, buildstate, forkNum);
tuplesort_end(buildstate->sortstate);
}
/*
* Build the index
*/
static void
BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
IvfflatBuildState * buildstate, ForkNumber forkNum)
{
InitBuildState(buildstate, heap, index, indexInfo);
ComputeCenters(buildstate);
/* Create pages */
CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum);
CreateListPages(index, buildstate->centers, buildstate->dimensions, buildstate->lists, forkNum, &buildstate->listInfo);
CreateEntryPages(buildstate, forkNum);
FreeBuildState(buildstate);
}
/*
* Build the index for a logged table
*/
IndexBuildResult *
ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo)
{
IndexBuildResult *result;
IvfflatBuildState buildstate;
BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM);
result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult));
result->heap_tuples = buildstate.reltuples;
result->index_tuples = buildstate.indtuples;
return result;
}
/*
* Build the index for an unlogged table
*/
void
ivfflatbuildempty(Relation index)
{
IndexInfo *indexInfo = BuildIndexInfo(index);
IvfflatBuildState buildstate;
BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM);
}

168
src/ivfflat.c Normal file
View File

@@ -0,0 +1,168 @@
#include "postgres.h"
#include "access/amapi.h"
#include "commands/vacuum.h"
#include "ivfflat.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"
static relopt_kind ivfflat_relopt_kind;
/*
* Initialize index options and variables
*/
void
_PG_init(void)
{
ivfflat_relopt_kind = add_reloption_kind();
add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists",
IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS
#if PG_VERSION_NUM >= 130000
,AccessExclusiveLock
#endif
);
DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes",
"Valid range is 1..lists.", &ivfflat_probes,
1, 1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL);
}
/*
* Estimate the cost of an index scan
*
* TODO Improve cost estimation
*/
static void
ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
Cost *indexStartupCost, Cost *indexTotalCost,
Selectivity *indexSelectivity, double *indexCorrelation
#if PG_VERSION_NUM >= 100000
,double *indexPages
#endif
)
{
GenericCosts costs;
#if PG_VERSION_NUM < 120000
List *qinfos;
qinfos = deconstruct_indexquals(path);
#endif
MemSet(&costs, 0, sizeof(costs));
#if PG_VERSION_NUM >= 120000
genericcostestimate(root, path, loop_count, &costs);
#else
genericcostestimate(root, path, loop_count, qinfos, &costs);
#endif
*indexStartupCost = costs.indexStartupCost;
*indexTotalCost = costs.indexTotalCost;
*indexSelectivity = costs.indexSelectivity;
*indexCorrelation = costs.indexCorrelation;
#if PG_VERSION_NUM >= 100000
*indexPages = costs.numIndexPages;
#endif
}
/*
* Parse and validate the reloptions
*/
static bytea *
ivfflatoptions(Datum reloptions, bool validate)
{
static const relopt_parse_elt tab[] = {
{"lists", RELOPT_TYPE_INT, offsetof(IvfflatOptions, lists)},
};
#if PG_VERSION_NUM >= 130000
return (bytea *) build_reloptions(reloptions, validate,
ivfflat_relopt_kind,
sizeof(IvfflatOptions),
tab, lengthof(tab));
#else
relopt_value *options;
int numoptions;
IvfflatOptions *rdopts;
options = parseRelOptions(reloptions, validate, ivfflat_relopt_kind, &numoptions);
rdopts = allocateReloptStruct(sizeof(IvfflatOptions), options, numoptions);
fillRelOptions((void *) rdopts, sizeof(IvfflatOptions), options, numoptions,
validate, tab, lengthof(tab));
return (bytea *) rdopts;
#endif
}
/*
* Validate catalog entries for the specified operator class
*/
static bool
ivfflatvalidate(Oid opclassoid)
{
return true;
}
PG_FUNCTION_INFO_V1(ivfflathandler);
Datum
ivfflathandler(PG_FUNCTION_ARGS)
{
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 4;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false; /* can change direction mid-scan */
amroutine->amcanunique = false;
amroutine->amcanmulticol = false;
amroutine->amoptionalkey = true;
amroutine->amsearcharray = false;
amroutine->amsearchnulls = false;
amroutine->amstorage = false;
amroutine->amclusterable = false;
amroutine->ampredlocks = false;
#if PG_VERSION_NUM >= 100000
amroutine->amcanparallel = false;
#endif
#if PG_VERSION_NUM >= 110000
amroutine->amcaninclude = false;
#endif
#if PG_VERSION_NUM >= 130000
amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */
amroutine->amparallelvacuumoptions = VACUUM_OPTION_NO_PARALLEL; /* TODO support parallel */
#endif
amroutine->amkeytype = InvalidOid;
amroutine->ambuild = ivfflatbuild;
amroutine->ambuildempty = ivfflatbuildempty;
amroutine->aminsert = ivfflatinsert;
amroutine->ambulkdelete = ivfflatbulkdelete;
amroutine->amvacuumcleanup = ivfflatvacuumcleanup;
amroutine->amcanreturn = NULL;
amroutine->amcostestimate = ivfflatcostestimate;
amroutine->amoptions = ivfflatoptions;
amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
#if PG_VERSION_NUM >= 120000
amroutine->ambuildphasename = NULL;
#endif
amroutine->amvalidate = ivfflatvalidate;
amroutine->ambeginscan = ivfflatbeginscan;
amroutine->amrescan = ivfflatrescan;
amroutine->amgettuple = ivfflatgettuple;
amroutine->amgetbitmap = NULL;
amroutine->amendscan = ivfflatendscan;
amroutine->ammarkpos = NULL;
amroutine->amrestrpos = NULL;
#if PG_VERSION_NUM >= 100000
amroutine->amestimateparallelscan = NULL;
amroutine->aminitparallelscan = NULL;
amroutine->amparallelrescan = NULL;
#endif
PG_RETURN_POINTER(amroutine);
}

193
src/ivfflat.h Normal file
View File

@@ -0,0 +1,193 @@
#ifndef IVFFLAT_H
#define IVFFLAT_H
#include "postgres.h"
#include "access/generic_xlog.h"
#include "access/reloptions.h"
#include "nodes/execnodes.h"
#include "utils/sampling.h"
#include "utils/tuplesort.h"
#include "vector.h"
/* Support functions */
#define IVFFLAT_DISTANCE_PROC 1
#define IVFFLAT_NORM_PROC 2
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
#define IVFFLAT_KMEANS_NORM_PROC 4
#define IVFFLAT_VERSION 1
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
#define IVFFLAT_PAGE_ID 0xFF84
/* Preserved page numbers */
#define IVFFLAT_METAPAGE_BLKNO 0
#define IVFFLAT_HEAD_BLKNO 1 /* first list page */
#define IVFFLAT_DEFAULT_LISTS 100
#define IVFFLAT_MAX_LISTS 32768
#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
#define IvfflatPageGetOpaque(page) ((IvfflatPageOpaque) PageGetSpecialPointer(page))
#define IvfflatPageGetMeta(page) ((IvfflatMetaPageData *) PageGetContents(page))
#if PG_VERSION_NUM < 100000
#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber
#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber
#endif
/* Variables */
int ivfflat_probes;
typedef struct VectorArrayData
{
int length;
int maxlen;
int dim;
Vector items[FLEXIBLE_ARRAY_MEMBER];
} VectorArrayData;
typedef VectorArrayData * VectorArray;
typedef struct ListInfo
{
BlockNumber blkno;
OffsetNumber offno;
} ListInfo;
/* IVFFlat index options */
typedef struct IvfflatOptions
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int lists; /* number of lists */
} IvfflatOptions;
typedef struct IvfflatBuildState
{
/* Info */
Relation heap;
Relation index;
IndexInfo *indexInfo;
/* Settings */
int dimensions;
int lists;
/* Statistics */
double indtuples;
double reltuples;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
/* Variables */
VectorArray samples;
VectorArray centers;
ListInfo *listInfo;
Vector *normvec;
/* Sampling */
BlockSamplerData bs;
ReservoirStateData rstate;
int rowstoskip;
/* Sorting */
Tuplesortstate *sortstate;
TupleDesc tupdesc;
TupleTableSlot *slot;
} IvfflatBuildState;
typedef struct IvfflatMetaPageData
{
uint32 magicNumber;
uint32 version;
uint16 dimensions;
uint16 lists;
} IvfflatMetaPageData;
typedef IvfflatMetaPageData * IvfflatMetaPage;
typedef struct IvfflatPageOpaqueData
{
BlockNumber nextblkno;
uint16 unused;
uint16 page_id; /* for identification of IVFFlat indexes */
} IvfflatPageOpaqueData;
typedef IvfflatPageOpaqueData * IvfflatPageOpaque;
typedef struct IvfflatListData
{
BlockNumber startPage;
BlockNumber insertPage;
Vector center;
} IvfflatListData;
typedef IvfflatListData * IvfflatList;
typedef struct IvfflatScanList
{
BlockNumber startPage;
double distance;
} IvfflatScanList;
typedef struct IvfflatScanOpaqueData
{
int probes;
bool first;
Buffer buf;
/* Sorting */
Tuplesortstate *sortstate;
TupleDesc tupdesc;
TupleTableSlot *slot;
bool isnull;
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
} IvfflatScanOpaqueData;
typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
#define VECTOR_ARRAY_SIZE(_length, _dim) (offsetof(VectorArrayData, items) + _length * VECTOR_SIZE(_dim))
#define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char*) _arr + offsetof(VectorArrayData, items) + (_offset) * VECTOR_SIZE(_arr->dim))
#define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset))
#define VectorArraySet(_arr, _offset, _val) (memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), _val, VECTOR_SIZE(_arr->dim)))
/* Methods */
void _PG_init(void);
VectorArray VectorArrayInit(int maxlen, int dimensions);
void PrintVectorArray(char *msg, VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
int IvfflatGetLists(Relation index);
void IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum);
void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);
void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum);
Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
void IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
/* Index access methods */
IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo);
void ivfflatbuildempty(Relation index);
bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 100000
,IndexInfo *indexInfo
#endif
);
IndexBulkDeleteResult *ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
IndexBulkDeleteResult *ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys);
void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir);
void ivfflatendscan(IndexScanDesc scan);
#endif

163
src/ivfinsert.c Normal file
View File

@@ -0,0 +1,163 @@
#include "postgres.h"
#include <float.h>
#include "ivfflat.h"
#include "storage/bufmgr.h"
/*
* Find the list that minimizes the distance function
*/
static void
FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * listInfo)
{
Buffer cbuf;
Page cpage;
IvfflatList list;
double distance;
double minDistance = DBL_MAX;
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
FmgrInfo *procinfo;
Oid collation;
OffsetNumber offno;
OffsetNumber maxoffno;
procinfo = index_getprocinfo(rel, 1, IVFFLAT_DISTANCE_PROC);
collation = rel->rd_indcollation[0];
/* Search all list pages */
while (BlockNumberIsValid(nextblkno))
{
cbuf = ReadBuffer(rel, nextblkno);
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
cpage = BufferGetPage(cbuf);
maxoffno = PageGetMaxOffsetNumber(cpage);
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center)));
if (distance < minDistance)
{
*insertPage = list->insertPage;
listInfo->blkno = nextblkno;
listInfo->offno = offno;
minDistance = distance;
}
}
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
UnlockReleaseBuffer(cbuf);
}
}
/*
* Prepare to insert an index tuple
*/
static void
LoadInsertPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, BlockNumber insertPage)
{
*buf = ReadBuffer(index, insertPage);
LockBuffer(*buf, BUFFER_LOCK_EXCLUSIVE);
*state = GenericXLogStart(index);
*page = GenericXLogRegisterBuffer(*state, *buf, 0);
}
/*
* Insert a tuple into the index
*/
static void
InsertTuple(Relation rel, IndexTuple itup, Relation heapRel, Datum *values)
{
Buffer buf;
Page page;
GenericXLogState *state;
Size itemsz;
BlockNumber insertPage = InvalidBlockNumber;
ListInfo listInfo;
bool newPage = false;
/* Find the insert page - sets the page and list info */
FindInsertPage(rel, values, &insertPage, &listInfo);
Assert(BlockNumberIsValid(insertPage));
itemsz = MAXALIGN(IndexTupleSize(itup));
Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)));
LoadInsertPage(rel, &buf, &page, &state, insertPage);
/* Find a page to insert the item */
while (PageGetFreeSpace(page) < itemsz)
{
insertPage = IvfflatPageGetOpaque(page)->nextblkno;
if (BlockNumberIsValid(insertPage))
{
/* Move to next page */
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
LoadInsertPage(rel, &buf, &page, &state, insertPage);
}
else
{
/* Add a new page */
IvfflatAppendPage(rel, &buf, &page, &state, MAIN_FORKNUM);
insertPage = BufferGetBlockNumber(buf);
newPage = true;
}
}
/* Add to next offset */
if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(rel));
IvfflatCommitBuffer(buf, state);
/* Update the insert page */
if (newPage)
IvfflatUpdateList(rel, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
}
/*
* Insert a tuple into the index
*/
bool
ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
Relation heap, IndexUniqueCheck checkUnique
#if PG_VERSION_NUM >= 100000
,IndexInfo *indexInfo
#endif
)
{
IndexTuple itup;
Datum value;
FmgrInfo *normprocinfo;
if (isnull[0])
return false;
value = values[0];
/* Normalize if needed */
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
if (normprocinfo != NULL)
{
if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL))
return false;
}
itup = index_form_tuple(RelationGetDescr(index), &value, isnull);
itup->t_tid = *heap_tid;
InsertTuple(index, itup, heap, &value);
pfree(itup);
/* Clean up if we allocated a new value */
if (value != values[0])
pfree(DatumGetPointer(value));
return false;
}

479
src/ivfkmeans.c Normal file
View File

@@ -0,0 +1,479 @@
#include "postgres.h"
#include <float.h>
#include "ivfflat.h"
#include "miscadmin.h"
/*
* Initialize with kmeans++
*
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
*/
static void
InitCenters(Relation index, VectorArray samples, VectorArray centers, double *lowerBound)
{
FmgrInfo *procinfo;
Oid collation;
int i;
int j;
double distance;
double sum;
double choice;
Vector *vec;
double *weight = palloc(samples->length * sizeof(double));
int numCenters = centers->maxlen;
int numSamples = samples->length;
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
collation = index->rd_indcollation[0];
/* Choose an initial center uniformly at random */
VectorArraySet(centers, 0, VectorArrayGet(samples, random() % samples->length));
centers->length++;
for (j = 0; j < numSamples; j++)
weight[j] = DBL_MAX;
for (i = 0; i < numCenters; i++)
{
CHECK_FOR_INTERRUPTS();
sum = 0.0;
for (j = 0; j < numSamples; j++)
{
vec = VectorArrayGet(samples, j);
/* Only need to compute distance for new center */
/* TODO Use triangle inequality to reduce distance calculations */
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
/* Set lower bound */
lowerBound[j * numCenters + i] = distance;
/* Use distance squared for weighted probability distribution */
distance *= distance;
if (distance < weight[j])
weight[j] = distance;
sum += weight[j];
}
/* Only compute lower bound on last iteration */
if (i + 1 == numCenters)
break;
/* Choose new center using weighted probability distribution. */
choice = sum * (((double) random()) / MAX_RANDOM_VALUE);
for (j = 0; j < numSamples - 1; j++)
{
choice -= weight[j];
if (choice <= 0)
break;
}
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
centers->length++;
}
pfree(weight);
}
/*
* Apply norm to vector
*/
static inline void
ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec)
{
int i;
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec)));
/* TODO Handle zero norm */
if (norm > 0)
{
for (i = 0; i < vec->dim; i++)
vec->x[i] /= norm;
}
}
/*
* Compare vectors
*/
static int
CompareVectors(const void *a, const void *b)
{
return vector_cmp_internal((Vector *) a, (Vector *) b);
}
/*
* Quick approach if we have little data
*/
static void
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
{
int i;
int j;
Vector *vec;
int dimensions = centers->dim;
Oid collation = index->rd_indcollation[0];
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
/* Copy existing vectors while avoiding duplicates */
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);
for (i = 0; i < samples->length; i++)
{
vec = VectorArrayGet(samples, i);
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0)
{
VectorArraySet(centers, centers->length, vec);
centers->length++;
}
}
/* Fill remaining with random data */
while (centers->length < centers->maxlen)
{
vec = VectorArrayGet(centers, centers->length);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
for (j = 0; j < dimensions; j++)
vec->x[j] = ((double) random()) / MAX_RANDOM_VALUE;
/* Normalize if needed (only needed for random centers) */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
centers->length++;
}
}
/*
* Use Elkan for performance. This requires distance function to satisfy triangle inequality.
*
* We use L2 distance for L2 (not L2 squared like index scan)
* and angular distance for inner product and cosine distance
*
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
*/
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
Vector *vec;
Vector *newCenter;
int iteration;
int j;
int k;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
int *centerCounts;
int *closestCenters;
double *lowerBound;
double *upperBound;
double *s;
double *halfcdist;
double *newcdist;
int changes;
double minDistance;
int closestCenter;
double distance;
bool rj;
bool rjreset;
double dxcx;
double dxc;
/* Set support functions */
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
collation = index->rd_indcollation[0];
/* Allocate space */
centerCounts = palloc(sizeof(int) * numCenters);
closestCenters = palloc(sizeof(int) * numSamples);
lowerBound = palloc(sizeof(double) * numSamples * numCenters);
upperBound = palloc(sizeof(double) * numSamples);
s = palloc(sizeof(double) * numCenters);
halfcdist = palloc(sizeof(double) * numCenters * numCenters);
newcdist = palloc(sizeof(double) * numCenters);
newCenters = VectorArrayInit(numCenters, dimensions);
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
vec->dim = dimensions;
}
/* Pick initial centers */
InitCenters(index, samples, centers, lowerBound);
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
for (j = 0; j < numSamples; j++)
{
minDistance = DBL_MAX;
closestCenter = -1;
vec = VectorArrayGet(samples, j);
/* Find closest center */
for (k = 0; k < numCenters; k++)
{
/* TODO Use Lemma 1 in k-means++ initialization */
distance = lowerBound[j * numCenters + k];
if (distance < minDistance)
{
minDistance = distance;
closestCenter = k;
}
}
upperBound[j] = minDistance;
closestCenters[j] = closestCenter;
}
/* Give 500 iterations to converge */
for (iteration = 0; iteration < 500; iteration++)
{
/* Can take a while, so ensure we can interrupt */
CHECK_FOR_INTERRUPTS();
changes = 0;
/* Step 1: For all centers, compute distance */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(centers, j);
for (k = j + 1; k < numCenters; k++)
{
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
halfcdist[j * numCenters + k] = distance;
halfcdist[k * numCenters + j] = distance;
}
}
/* For all centers c, compute s(c) */
for (j = 0; j < numCenters; j++)
{
minDistance = DBL_MAX;
for (k = 0; k < numCenters; k++)
{
if (j == k)
continue;
distance = halfcdist[j * numCenters + k];
if (distance < minDistance)
minDistance = distance;
}
s[j] = minDistance;
}
rjreset = iteration != 0;
for (j = 0; j < numSamples; j++)
{
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
if (upperBound[j] <= s[closestCenters[j]])
continue;
rj = rjreset;
for (k = 0; k < numCenters; k++)
{
/* Step 3: For all remaining points x and centers c */
if (k == closestCenters[j])
continue;
if (upperBound[j] <= lowerBound[j * numCenters + k])
continue;
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
continue;
vec = VectorArrayGet(samples, j);
/* Step 3a */
if (rj)
{
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
/* d(x,c(x)) computed, which is a form of d(x,c) */
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
upperBound[j] = dxcx;
rj = false;
}
else
dxcx = upperBound[j];
/* Step 3b */
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
{
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
/* d(x,c) calculated */
lowerBound[j * numCenters + k] = dxc;
if (dxc < dxcx)
{
closestCenters[j] = k;
/* c(x) changed */
upperBound[j] = dxc;
changes++;
}
}
}
}
/* Step 4: For each center c, let m(c) be mean of all points assigned */
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
for (k = 0; k < dimensions; k++)
vec->x[k] = 0.0;
centerCounts[j] = 0;
}
for (j = 0; j < numSamples; j++)
{
vec = VectorArrayGet(samples, j);
closestCenter = closestCenters[j];
/* Increment sum and count of closest center */
newCenter = VectorArrayGet(newCenters, closestCenter);
for (k = 0; k < dimensions; k++)
newCenter->x[k] += vec->x[k];
centerCounts[closestCenter] += 1;
}
for (j = 0; j < numCenters; j++)
{
vec = VectorArrayGet(newCenters, j);
if (centerCounts[j] > 0)
{
for (k = 0; k < dimensions; k++)
vec->x[k] /= centerCounts[j];
}
else
{
/* TODO Handle empty centers properly */
for (k = 0; k < dimensions; k++)
vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
}
/* Normalize if needed */
if (normprocinfo != NULL)
ApplyNorm(normprocinfo, collation, vec);
}
/* Step 5 */
for (j = 0; j < numCenters; j++)
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
for (j = 0; j < numSamples; j++)
{
for (k = 0; k < numCenters; k++)
{
distance = lowerBound[j * numCenters + k] - newcdist[k];
if (distance < 0)
distance = 0;
lowerBound[j * numCenters + k] = distance;
}
}
/* Step 6 */
/* We reset r(x) before Step 3 in the next iteration */
for (j = 0; j < numSamples; j++)
upperBound[j] += newcdist[closestCenters[j]];
/* Step 7 */
for (j = 0; j < numCenters; j++)
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
if (changes == 0 && iteration != 0)
break;
}
pfree(newCenters);
pfree(centerCounts);
pfree(closestCenters);
pfree(lowerBound);
pfree(upperBound);
pfree(s);
pfree(halfcdist);
pfree(newcdist);
}
/*
* Detect issues with centers
*/
static void
CheckCenters(Relation index, VectorArray centers)
{
FmgrInfo *normprocinfo;
Oid collation;
int i;
double norm;
if (centers->length != centers->maxlen)
elog(ERROR, "Not enough centers. Please report a bug.");
/* Ensure no duplicate centers */
/* Fine to sort in-place */
qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);
for (i = 1; i < centers->length; i++)
{
if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0)
elog(ERROR, "Duplicate centers detected. Please report a bug.");
}
/* Ensure no zero vectors for cosine distance */
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
if (normprocinfo != NULL)
{
collation = index->rd_indcollation[0];
for (i = 0; i < centers->length; i++)
{
norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
if (norm == 0)
elog(ERROR, "Zero norm detected. Please report a bug.");
}
}
}
/*
* Perform naive k-means centering
* We use spherical k-means for inner product and cosine
*/
void
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
{
if (samples->length <= centers->maxlen)
QuickCenters(index, samples, centers);
else
ElkanKmeans(index, samples, centers);
CheckCenters(index, centers);
}

327
src/ivfscan.c Normal file
View File

@@ -0,0 +1,327 @@
#include "postgres.h"
#include "access/relscan.h"
#include "ivfflat.h"
#include "miscadmin.h"
#include "storage/bufmgr.h"
#if PG_VERSION_NUM >= 110000
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#else
#include "catalog/pg_operator.h"
#include "catalog/pg_type.h"
#endif
/*
* Compare list distances
*/
static int
CompareLists(const void *a, const void *b)
{
double diff = (((IvfflatScanList *) a)->distance - ((IvfflatScanList *) b)->distance);
if (diff > 0)
return 1;
if (diff < 0)
return -1;
return 0;
}
/*
* Get lists and sort by distance
*/
static void
GetScanLists(IndexScanDesc scan, Datum value)
{
Buffer cbuf;
Page cpage;
IvfflatList list;
OffsetNumber offno;
OffsetNumber maxoffno;
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
int listCount = 0;
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
double distance;
/* Search all list pages */
while (BlockNumberIsValid(nextblkno))
{
cbuf = ReadBuffer(scan->indexRelation, nextblkno);
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
cpage = BufferGetPage(cbuf);
maxoffno = PageGetMaxOffsetNumber(cpage);
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
/* Use procinfo from the index instead of scan key for performance */
distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
so->lists[listCount].startPage = list->startPage;
so->lists[listCount].distance = distance;
listCount++;
}
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
UnlockReleaseBuffer(cbuf);
}
/* Sort by distance */
qsort(so->lists, listCount, sizeof(IvfflatScanList), CompareLists);
if (so->probes > listCount)
so->probes = listCount;
}
/*
* Get items
*/
static void
GetScanItems(IndexScanDesc scan, Datum value)
{
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
Buffer buf;
Page page;
IndexTuple itup;
BlockNumber searchPage;
OffsetNumber offno;
OffsetNumber maxoffno;
Datum datum;
bool isnull;
int i;
TupleDesc tupdesc = RelationGetDescr(scan->indexRelation);
#if PG_VERSION_NUM >= 120000
TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsVirtual);
#else
TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc);
#endif
/*
* Reuse same set of shared buffers for scan
*
* See postgres/src/backend/storage/buffer/README for description
*/
BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
/* Search closest probes lists */
for (i = 0; i < so->probes; i++)
{
searchPage = so->lists[i].startPage;
/* Search all entry pages for list */
while (BlockNumberIsValid(searchPage))
{
buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
LockBuffer(buf, BUFFER_LOCK_SHARE);
page = BufferGetPage(buf);
maxoffno = PageGetMaxOffsetNumber(page);
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
datum = index_getattr(itup, 1, tupdesc, &isnull);
/*
* Add virtual tuple
*
* Use procinfo from the index instead of scan key for
* performance
*/
ExecClearTuple(slot);
slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value);
slot->tts_isnull[0] = false;
slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid));
slot->tts_isnull[1] = false;
slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid));
slot->tts_isnull[2] = false;
slot->tts_values[3] = Int32GetDatum((int) searchPage);
slot->tts_isnull[3] = false;
ExecStoreVirtualTuple(slot);
tuplesort_puttupleslot(so->sortstate, slot);
}
searchPage = IvfflatPageGetOpaque(page)->nextblkno;
UnlockReleaseBuffer(buf);
}
}
}
/*
* Prepare for an index scan
*/
IndexScanDesc
ivfflatbeginscan(Relation index, int nkeys, int norderbys)
{
IndexScanDesc scan;
IvfflatScanOpaque so;
int lists;
AttrNumber attNums[] = {1};
Oid sortOperators[] = {Float8LessOperator};
Oid sortCollations[] = {InvalidOid};
bool nullsFirstFlags[] = {false};
scan = RelationGetIndexScan(index, nkeys, norderbys);
lists = IvfflatGetLists(scan->indexRelation);
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList));
so->buf = InvalidBuffer;
so->first = true;
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
so->collation = index->rd_indcollation[0];
/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
so->tupdesc = CreateTemplateTupleDesc(4);
#else
so->tupdesc = CreateTemplateTupleDesc(4, false);
#endif
TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0);
/* Prep sort */
#if PG_VERSION_NUM >= 110000
so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false);
#else
so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, false);
#endif
#if PG_VERSION_NUM >= 120000
so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple);
#else
so->slot = MakeSingleTupleTableSlot(so->tupdesc);
#endif
scan->opaque = so;
return scan;
}
/*
* Start or restart an index scan
*/
void
ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
{
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
#if PG_VERSION_NUM >= 130000
if (!so->first)
tuplesort_reset(so->sortstate);
#endif
so->first = true;
so->probes = ivfflat_probes;
if (keys && scan->numberOfKeys > 0)
memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
if (orderbys && scan->numberOfOrderBys > 0)
memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
}
/*
* Fetch the next tuple in the given scan
*/
bool
ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
{
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
/*
* Index can be used to scan backward, but Postgres doesn't support
* backward scan on operators
*/
Assert(ScanDirectionIsForward(dir));
if (so->first)
{
Datum value;
/* No items will match if null */
if (scan->orderByData->sk_flags & SK_ISNULL)
return false;
value = scan->orderByData->sk_argument;
if (so->normprocinfo != NULL)
{
/* No items will match if normalization fails */
if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL))
return false;
}
GetScanLists(scan, value);
GetScanItems(scan, value);
tuplesort_performsort(so->sortstate);
so->first = false;
/* Clean up if we allocated a new value */
if (value != scan->orderByData->sk_argument)
pfree(DatumGetPointer(value));
}
#if PG_VERSION_NUM >= 100000
if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL))
#else
if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL))
#endif
{
BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull));
OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));
BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull));
#if PG_VERSION_NUM >= 120000
ItemPointerSet(&scan->xs_heaptid, blkno, offset);
#else
ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset);
#endif
if (BufferIsValid(so->buf))
ReleaseBuffer(so->buf);
/*
* An index scan must maintain a pin on the index page holding the
* item last returned by amgettuple
*
* https://www.postgresql.org/docs/current/index-locking.html
*/
so->buf = ReadBuffer(scan->indexRelation, indexblkno);
scan->xs_recheckorderby = false;
return true;
}
return false;
}
/*
* End a scan and release resources
*/
void
ivfflatendscan(IndexScanDesc scan)
{
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
/* Release pin */
if (BufferIsValid(so->buf))
ReleaseBuffer(so->buf);
tuplesort_end(so->sortstate);
pfree(so);
scan->opaque = NULL;
}

176
src/ivfutils.c Normal file
View File

@@ -0,0 +1,176 @@
#include "postgres.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
#include "vector.h"
/*
* Allocate a vector array
*/
VectorArray
VectorArrayInit(int maxlen, int dimensions)
{
VectorArray res = palloc0(VECTOR_ARRAY_SIZE(maxlen, dimensions));
res->length = 0;
res->maxlen = maxlen;
res->dim = dimensions;
return res;
}
/*
* Print vector array - useful for debugging
*/
void
PrintVectorArray(char *msg, VectorArray arr)
{
int i;
for (i = 0; i < arr->length; i++)
PrintVector(msg, VectorArrayGet(arr, i));
}
/*
* Get the number of lists in the index
*/
int
IvfflatGetLists(Relation index)
{
IvfflatOptions *opts = (IvfflatOptions *) index->rd_options;
if (opts)
return opts->lists;
return IVFFLAT_DEFAULT_LISTS;
}
/*
* Get proc
*/
FmgrInfo *
IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
{
if (!OidIsValid(index_getprocid(rel, 1, procnum)))
return NULL;
return index_getprocinfo(rel, 1, procnum);
}
/*
* Divide by the norm
*
* Returns false if value should not be indexed
*
* The caller needs to free the pointer stored in value
* if it's different than the original value
*/
bool
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
{
Vector *v;
int i;
double norm;
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
v = (Vector *) DatumGetPointer(*value);
if (result == NULL)
result = InitVector(v->dim);
for (i = 0; i < v->dim; i++)
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
return true;
}
return false;
}
/*
* New buffer
*/
Buffer
IvfflatNewBuffer(Relation index, ForkNumber forkNum)
{
Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
return buf;
}
/*
* Init page
*/
void
IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
*state = GenericXLogStart(index);
*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
PageInit(*page, BufferGetPageSize(*buf), sizeof(IvfflatPageOpaqueData));
IvfflatPageGetOpaque(*page)->nextblkno = InvalidBlockNumber;
IvfflatPageGetOpaque(*page)->page_id = IVFFLAT_PAGE_ID;
}
/*
* Commit buffer
*/
void
IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
{
MarkBufferDirty(buf);
GenericXLogFinish(state);
UnlockReleaseBuffer(buf);
}
/*
* Add a new page
*
* The order is very important!!
*/
void
IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
{
Buffer prevbuf = *buf;
/* Get new buffer */
*buf = IvfflatNewBuffer(index, forkNum);
/* Update and commit previous buffer */
IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(*buf);
IvfflatCommitBuffer(prevbuf, *state);
/* Init new page */
IvfflatInitPage(index, buf, page, state);
}
/*
* Update the start or insert page of a list
*/
void
IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo,
BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum)
{
Buffer buf;
Page page;
IvfflatList list;
buf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL);
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
list = (IvfflatList) PageGetItem(page, PageGetItemId(page, listInfo.offno));
if (BlockNumberIsValid(insertPage))
list->insertPage = insertPage;
if (BlockNumberIsValid(startPage))
list->startPage = startPage;
/* Could only commit if changed, but extra complexity isn't needed */
IvfflatCommitBuffer(buf, state);
}

151
src/ivfvacuum.c Normal file
View File

@@ -0,0 +1,151 @@
#include "postgres.h"
#include "commands/vacuum.h"
#include "ivfflat.h"
#include "storage/bufmgr.h"
/*
* Bulk delete tuples from the index
*/
IndexBulkDeleteResult *
ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
IndexBulkDeleteCallback callback, void *callback_state)
{
Relation index = info->index;
Buffer cbuf;
Page cpage;
Buffer buf;
Page page;
IvfflatList list;
IndexTuple itup;
ItemPointer htup;
OffsetNumber deletable[MaxOffsetNumber];
int ndeletable;
OffsetNumber startPages[MaxOffsetNumber];
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
BlockNumber searchPage;
BlockNumber insertPage;
GenericXLogState *state;
OffsetNumber coffno;
OffsetNumber cmaxoffno;
OffsetNumber offno;
OffsetNumber maxoffno;
ListInfo listInfo;
BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
if (stats == NULL)
stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult));
/* Iterate over list pages */
while (BlockNumberIsValid(nextblkno))
{
cbuf = ReadBuffer(index, nextblkno);
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
cpage = BufferGetPage(cbuf);
cmaxoffno = PageGetMaxOffsetNumber(cpage);
/* Iterate over lists */
for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
{
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno));
startPages[coffno - FirstOffsetNumber] = list->startPage;
}
listInfo.blkno = nextblkno;
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
UnlockReleaseBuffer(cbuf);
for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
{
searchPage = startPages[coffno - FirstOffsetNumber];
insertPage = InvalidBlockNumber;
/* Iterate over entry pages */
while (BlockNumberIsValid(searchPage))
{
vacuum_delay_point();
buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
/*
* ambulkdelete cannot delete entries from pages that are
* pinned by other backends
*
* https://www.postgresql.org/docs/current/index-locking.html
*/
LockBufferForCleanup(buf);
state = GenericXLogStart(index);
page = GenericXLogRegisterBuffer(state, buf, 0);
maxoffno = PageGetMaxOffsetNumber(page);
ndeletable = 0;
/* Find deleted tuples */
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
{
itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
htup = &(itup->t_tid);
if (callback(htup, callback_state))
{
deletable[ndeletable++] = offno;
stats->tuples_removed++;
}
else
stats->num_index_tuples++;
}
searchPage = IvfflatPageGetOpaque(page)->nextblkno;
if (ndeletable > 0)
{
/* Delete tuples */
PageIndexMultiDelete(page, deletable, ndeletable);
MarkBufferDirty(buf);
GenericXLogFinish(state);
/* Set to first free page */
if (!BlockNumberIsValid(insertPage))
insertPage = searchPage;
}
else
GenericXLogAbort(state);
UnlockReleaseBuffer(buf);
}
/*
* Update after all tuples deleted.
*
* We don't add or delete items from lists pages, so offset won't
* change.
*/
if (!BlockNumberIsValid(insertPage))
{
listInfo.offno = coffno;
IvfflatUpdateList(index, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
}
}
}
return stats;
}
/*
* Clean up after a VACUUM operation
*/
IndexBulkDeleteResult *
ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats)
{
Relation rel = info->index;
if (stats == NULL)
return NULL;
stats->num_pages = RelationGetNumberOfBlocks(rel);
return stats;
}

610
src/vector.c Normal file
View File

@@ -0,0 +1,610 @@
#include "postgres.h"
#include <math.h>
#include "vector.h"
#include "fmgr.h"
#include "catalog/pg_type.h"
#include "lib/stringinfo.h"
#include "utils/array.h"
#include "utils/builtins.h"
#include "utils/lsyscache.h"
#if PG_VERSION_NUM >= 120000
#include "utils/float.h"
#endif
PG_MODULE_MAGIC;
/*
* Ensure same dimensions
*/
static inline void
CheckDims(Vector * a, Vector * b)
{
if (a->dim != b->dim)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("different vector dimensions %d and %d", a->dim, b->dim)));
}
/*
* Ensure expected dimension
*/
static inline void
CheckExpectedDim(int32 typmod, int dim)
{
if (typmod != -1 && typmod != dim)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("expected %d dimensions, not %d", typmod, dim)));
}
/*
* Ensure finite elements
*/
static inline void
CheckElement(float value)
{
if (isnan(value))
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("NaN not allowed in vector")));
if (isinf(value))
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("infinite value not allowed in vector")));
}
/*
* Print vector - useful for debugging
*/
void
PrintVector(char *msg, Vector * vector)
{
StringInfoData buf;
int dim = vector->dim;
int i;
initStringInfo(&buf);
appendStringInfoChar(&buf, '[');
for (i = 0; i < dim; i++)
{
if (i > 0)
appendStringInfoString(&buf, ",");
appendStringInfoString(&buf, float8out_internal(vector->x[i]));
}
appendStringInfoChar(&buf, ']');
elog(INFO, "%s = %s", msg, buf.data);
}
/*
* Convert textual representation to internal representation
*/
PG_FUNCTION_INFO_V1(vector_in);
Datum
vector_in(PG_FUNCTION_ARGS)
{
char *str = PG_GETARG_CSTRING(0);
int32 typmod = PG_GETARG_INT32(2);
int i;
double x[VECTOR_MAX_DIM];
int dim = 0;
char *pt;
char *stringEnd;
Vector *result;
if (*str != '[')
ereport(ERROR,
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
errmsg("malformed vector literal: \"%s\"", str),
errdetail("Vector contents must start with \"[\".")));
str++;
pt = strtok(str, ",");
stringEnd = pt;
while (pt != NULL && *stringEnd != ']')
{
if (dim == VECTOR_MAX_DIM)
ereport(ERROR,
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));
x[dim] = strtod(pt, &stringEnd);
CheckElement(x[dim]);
dim++;
if (stringEnd == pt)
ereport(ERROR,
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
errmsg("invalid input syntax for type vector: \"%s\"", pt)));
if (*stringEnd != '\0' && *stringEnd != ']')
ereport(ERROR,
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
errmsg("invalid input syntax for type vector: \"%s\"", pt)));
pt = strtok(NULL, ",");
}
if (*stringEnd != ']')
ereport(ERROR,
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
errmsg("malformed vector literal"),
errdetail("Unexpected end of input.")));
if (stringEnd[1] != '\0')
ereport(ERROR,
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
errmsg("malformed vector literal"),
errdetail("Junk after closing right brace.")));
if (dim < 1)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("vector must have at least 1 dimension")));
CheckExpectedDim(typmod, dim);
result = InitVector(dim);
for (i = 0; i < dim; i++)
result->x[i] = x[i];
PG_RETURN_POINTER(result);
}
/*
* Convert internal representation to textual representation
*/
PG_FUNCTION_INFO_V1(vector_out);
Datum
vector_out(PG_FUNCTION_ARGS)
{
Vector *vector = PG_GETARG_VECTOR_P(0);
StringInfoData buf;
int dim = vector->dim;
int i;
initStringInfo(&buf);
appendStringInfoChar(&buf, '[');
for (i = 0; i < dim; i++)
{
if (i > 0)
appendStringInfoString(&buf, ",");
appendStringInfoString(&buf, float8out_internal(vector->x[i]));
}
appendStringInfoChar(&buf, ']');
PG_FREE_IF_COPY(vector, 0);
PG_RETURN_CSTRING(buf.data);
}
/*
* Convert type modifier
*/
PG_FUNCTION_INFO_V1(vector_typmod_in);
Datum
vector_typmod_in(PG_FUNCTION_ARGS)
{
ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0);
int32 *tl;
int n;
tl = ArrayGetIntegerTypmods(ta, &n);
if (n != 1)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("invalid type modifier")));
if (*tl < 1)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("dimensions for type vector must be at least 1")));
if (*tl > VECTOR_MAX_DIM)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("dimensions for type vector cannot exceed %d", VECTOR_MAX_DIM)));
PG_RETURN_INT32(*tl);
}
/*
* Convert vector to vector
*/
PG_FUNCTION_INFO_V1(vector);
Datum
vector(PG_FUNCTION_ARGS)
{
Vector *arg = PG_GETARG_VECTOR_P(0);
int32 typmod = PG_GETARG_INT32(1);
CheckExpectedDim(typmod, arg->dim);
PG_RETURN_POINTER(arg);
}
/*
* Convert array to vector
*/
PG_FUNCTION_INFO_V1(array_to_vector);
Datum
array_to_vector(PG_FUNCTION_ARGS)
{
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
int32 typmod = PG_GETARG_INT32(1);
int i;
Vector *result;
int16 typlen;
bool typbyval;
char typalign;
Datum *elemsp;
bool *nullsp;
int nelemsp;
if (ARR_NDIM(array) > 1)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("array must be 1-D")));
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, &nullsp, &nelemsp);
if (typmod == -1)
{
if (nelemsp < 1)
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("vector must have at least 1 dimension")));
if (nelemsp > VECTOR_MAX_DIM)
ereport(ERROR,
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));
}
else
CheckExpectedDim(typmod, nelemsp);
result = InitVector(nelemsp);
for (i = 0; i < nelemsp; i++)
{
if (nullsp[i])
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("array must not containing NULLs")));
if (ARR_ELEMTYPE(array) == INT4OID)
result->x[i] = DatumGetInt32(elemsp[i]);
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
result->x[i] = DatumGetFloat8(elemsp[i]);
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
result->x[i] = DatumGetFloat4(elemsp[i]);
else
ereport(ERROR,
(errcode(ERRCODE_DATA_EXCEPTION),
errmsg("unsupported array type")));
CheckElement(result->x[i]);
}
PG_RETURN_POINTER(result);
}
/*
* Get the L2 distance between vectors
*/
PG_FUNCTION_INFO_V1(l2_distance);
Datum
l2_distance(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
distance += pow(a->x[i] - b->x[i], 2);
PG_RETURN_FLOAT8(sqrt(distance));
}
/*
* Get the L2 squared distance between vectors
* This saves a sqrt calculation
*/
PG_FUNCTION_INFO_V1(vector_l2_squared_distance);
Datum
vector_l2_squared_distance(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
distance += pow(a->x[i] - b->x[i], 2);
PG_RETURN_FLOAT8(distance);
}
/*
* Get the inner product of two vectors
*/
PG_FUNCTION_INFO_V1(inner_product);
Datum
inner_product(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
distance += a->x[i] * b->x[i];
PG_RETURN_FLOAT8(distance);
}
/*
* Get the negative inner product of two vectors
*/
PG_FUNCTION_INFO_V1(vector_negative_inner_product);
Datum
vector_negative_inner_product(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
distance += a->x[i] * b->x[i];
PG_RETURN_FLOAT8(distance * -1);
}
/*
* Get the cosine distance between two vectors
*/
PG_FUNCTION_INFO_V1(cosine_distance);
Datum
cosine_distance(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
double norma = 0.0;
double normb = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
{
distance += a->x[i] * b->x[i];
norma += pow(a->x[i], 2);
normb += pow(b->x[i], 2);
}
PG_RETURN_FLOAT8(1 - (distance / (sqrt(norma) * sqrt(normb))));
}
/*
* Get the distance for spherical k-means
* Currently uses angular distance since needs to satisfy triangle inequality
* Assumes inputs are unit vectors (skips norm)
*/
PG_FUNCTION_INFO_V1(vector_spherical_distance);
Datum
vector_spherical_distance(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
double distance = 0.0;
CheckDims(a, b);
for (int i = 0; i < a->dim; i++)
distance += a->x[i] * b->x[i];
/* Prevent NaN with acos with loss of precision */
if (distance > 1)
distance = 1;
else if (distance < -1)
distance = -1;
PG_RETURN_FLOAT8(acos(distance) / M_PI);
}
/*
* Get the dimensions of a vector
*/
PG_FUNCTION_INFO_V1(vector_dims);
Datum
vector_dims(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
PG_RETURN_INT32(a->dim);
}
/*
* Get the L2 norm of a vector
*/
PG_FUNCTION_INFO_V1(vector_norm);
Datum
vector_norm(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
double norm = 0.0;
for (int i = 0; i < a->dim; i++)
norm += pow(a->x[i], 2);
PG_RETURN_FLOAT8(sqrt(norm));
}
/*
* Add vectors
*/
PG_FUNCTION_INFO_V1(vector_add);
Datum
vector_add(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
Vector *result;
int i;
CheckDims(a, b);
result = InitVector(a->dim);
for (i = 0; i < a->dim; i++)
result->x[i] = a->x[i] + b->x[i];
PG_RETURN_POINTER(result);
}
/*
* Subtract vectors
*/
PG_FUNCTION_INFO_V1(vector_sub);
Datum
vector_sub(PG_FUNCTION_ARGS)
{
Vector *a = PG_GETARG_VECTOR_P(0);
Vector *b = PG_GETARG_VECTOR_P(1);
Vector *result;
int i;
CheckDims(a, b);
result = InitVector(a->dim);
for (i = 0; i < a->dim; i++)
result->x[i] = a->x[i] - b->x[i];
PG_RETURN_POINTER(result);
}
/*
* Internal helper to compare vectors
*/
int
vector_cmp_internal(Vector * a, Vector * b)
{
int i;
CheckDims(a, b);
for (i = 0; i < a->dim; i++)
{
if (a->x[i] < b->x[i])
return -1;
if (a->x[i] > b->x[i])
return 1;
}
return 0;
}
/*
* Less than
*/
PG_FUNCTION_INFO_V1(vector_lt);
Datum
vector_lt(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) < 0);
}
/*
* Less than or equal
*/
PG_FUNCTION_INFO_V1(vector_le);
Datum
vector_le(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) <= 0);
}
/*
* Equal
*/
PG_FUNCTION_INFO_V1(vector_eq);
Datum
vector_eq(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) == 0);
}
/*
* Not equal
*/
PG_FUNCTION_INFO_V1(vector_ne);
Datum
vector_ne(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) != 0);
}
/*
* Greater than or equal
*/
PG_FUNCTION_INFO_V1(vector_ge);
Datum
vector_ge(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) >= 0);
}
/*
* Greater than
*/
PG_FUNCTION_INFO_V1(vector_gt);
Datum
vector_gt(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_BOOL(vector_cmp_internal(a, b) > 0);
}
/*
* Compare vectors
*/
PG_FUNCTION_INFO_V1(vector_cmp);
Datum
vector_cmp(PG_FUNCTION_ARGS)
{
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
PG_RETURN_INT32(vector_cmp_internal(a, b));
}

41
src/vector.h Normal file
View File

@@ -0,0 +1,41 @@
#ifndef VECTOR_H
#define VECTOR_H
#include "postgres.h"
#define VECTOR_MAX_DIM 1024
#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))
#define DatumGetVector(x) ((Vector *) PG_DETOAST_DATUM(x))
#define PG_GETARG_VECTOR_P(x) DatumGetVector(PG_GETARG_DATUM(x))
#define PG_RETURN_VECTOR_P(x) PG_RETURN_POINTER(x)
typedef struct Vector
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int16 dim; /* number of dimensions */
int16 unused;
float x[FLEXIBLE_ARRAY_MEMBER];
} Vector;
void PrintVector(char *msg, Vector * vector);
int vector_cmp_internal(Vector * a, Vector * b);
/*
* Allocate and initialize a new vector
*/
static inline Vector *
InitVector(int dim)
{
Vector *result;
int size;
size = VECTOR_SIZE(dim);
result = (Vector *) palloc0(size);
SET_VARSIZE(result, size);
result->dim = dim;
return result;
}
#endif