mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-03 11:10:56 +08:00
Added src directory
This commit is contained in:
503
src/ivfbuild.c
Normal file
503
src/ivfbuild.c
Normal file
@@ -0,0 +1,503 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "catalog/index.h"
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
#include "access/tableam.h"
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
#include "catalog/pg_operator_d.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#else
|
||||
#include "catalog/pg_operator.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
#define CALLBACK_ITEM_POINTER ItemPointer tid
|
||||
#else
|
||||
#define CALLBACK_ITEM_POINTER HeapTuple hup
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Callback for sampling
|
||||
*/
|
||||
static void
|
||||
SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
bool *isnull, bool tupleIsAlive, void *state)
|
||||
{
|
||||
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
|
||||
VectorArray samples = buildstate->samples;
|
||||
int targsamples = samples->maxlen;
|
||||
Datum value = values[0];
|
||||
|
||||
/* Skip nulls */
|
||||
if (isnull[0])
|
||||
return;
|
||||
|
||||
/* Normalize the value */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
{
|
||||
VectorArraySet(samples, samples->length, DatumGetVector(value));
|
||||
samples->length++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (buildstate->rowstoskip < 0)
|
||||
buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples);
|
||||
|
||||
if (buildstate->rowstoskip <= 0)
|
||||
{
|
||||
int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate));
|
||||
|
||||
Assert(k >= 0 && k < targsamples);
|
||||
VectorArraySet(samples, k, DatumGetVector(value));
|
||||
}
|
||||
|
||||
buildstate->rowstoskip -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Sample rows with same logic as ANALYZE
|
||||
*/
|
||||
static void
|
||||
SampleRows(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int targsamples = buildstate->samples->maxlen;
|
||||
BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap);
|
||||
|
||||
buildstate->rowstoskip = -1;
|
||||
|
||||
BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, random());
|
||||
|
||||
reservoir_init_selection_state(&buildstate->rstate, targsamples);
|
||||
while (BlockSampler_HasMore(&buildstate->bs))
|
||||
{
|
||||
BlockNumber targblock = BlockSampler_Next(&buildstate->bs);
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
false, true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#elif PG_VERSION_NUM >= 110000
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate, NULL);
|
||||
#else
|
||||
IndexBuildHeapRangeScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, targblock, 1, SampleCallback, (void *) buildstate);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Callback for table_index_build_scan
|
||||
*/
|
||||
static void
|
||||
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
|
||||
bool *isnull, bool tupleIsAlive, void *state)
|
||||
{
|
||||
IvfflatBuildState *buildstate = (IvfflatBuildState *) state;
|
||||
double distance;
|
||||
double minDistance = DBL_MAX;
|
||||
int closestCenter = -1;
|
||||
VectorArray centers = buildstate->centers;
|
||||
TupleTableSlot *slot = buildstate->slot;
|
||||
Datum value = values[0];
|
||||
int i;
|
||||
|
||||
#if PG_VERSION_NUM < 130000
|
||||
ItemPointer tid = &hup->t_self;
|
||||
#endif
|
||||
|
||||
if (isnull[0])
|
||||
return;
|
||||
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
|
||||
/* Find the list that minimizes the distance */
|
||||
for (i = 0; i < centers->length; i++)
|
||||
{
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
|
||||
if (distance < minDistance)
|
||||
{
|
||||
minDistance = distance;
|
||||
closestCenter = i;
|
||||
}
|
||||
}
|
||||
|
||||
/* Create a virtual tuple */
|
||||
ExecClearTuple(slot);
|
||||
slot->tts_values[0] = Int32GetDatum(closestCenter);
|
||||
slot->tts_isnull[0] = false;
|
||||
slot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid));
|
||||
slot->tts_isnull[1] = false;
|
||||
slot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid));
|
||||
slot->tts_isnull[2] = false;
|
||||
slot->tts_values[3] = value;
|
||||
slot->tts_isnull[3] = false;
|
||||
ExecStoreVirtualTuple(slot);
|
||||
|
||||
/*
|
||||
* Add tuple to sort
|
||||
*
|
||||
* tuplesort_puttupleslot comment: Input data is always copied; the caller
|
||||
* need not save it.
|
||||
*/
|
||||
tuplesort_puttupleslot(buildstate->sortstate, slot);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get index tuple from sort state
|
||||
*/
|
||||
static inline void
|
||||
GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list)
|
||||
{
|
||||
Datum value;
|
||||
bool isnull;
|
||||
int tupblk;
|
||||
int tupoff;
|
||||
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL))
|
||||
#else
|
||||
if (tuplesort_gettupleslot(sortstate, true, slot, NULL))
|
||||
#endif
|
||||
{
|
||||
*list = DatumGetInt32(slot_getattr(slot, 1, &isnull));
|
||||
tupblk = DatumGetInt32(slot_getattr(slot, 2, &isnull));
|
||||
tupoff = DatumGetInt32(slot_getattr(slot, 3, &isnull));
|
||||
value = slot_getattr(slot, 4, &isnull);
|
||||
|
||||
/* Form the index tuple */
|
||||
*itup = index_form_tuple(tupdesc, &value, &isnull);
|
||||
ItemPointerSet(&(*itup)->t_tid, tupblk, tupoff);
|
||||
}
|
||||
else
|
||||
*list = -1;
|
||||
}
|
||||
|
||||
/*
|
||||
* Create initial entry pages
|
||||
*/
|
||||
static void
|
||||
InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum)
|
||||
{
|
||||
Buffer buf;
|
||||
Page page;
|
||||
GenericXLogState *state;
|
||||
int list;
|
||||
IndexTuple itup = NULL; /* silence compiler warning */
|
||||
BlockNumber startPage = InvalidBlockNumber;
|
||||
BlockNumber insertPage = InvalidBlockNumber;
|
||||
Size itemsz;
|
||||
int i;
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple);
|
||||
#else
|
||||
TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
|
||||
#endif
|
||||
TupleDesc tupdesc = RelationGetDescr(index);
|
||||
|
||||
GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);
|
||||
|
||||
for (i = 0; i < buildstate->centers->length; i++)
|
||||
{
|
||||
buf = IvfflatNewBuffer(index, forkNum);
|
||||
IvfflatInitPage(index, &buf, &page, &state);
|
||||
|
||||
startPage = BufferGetBlockNumber(buf);
|
||||
|
||||
/* Get all tuples for list */
|
||||
while (list == i)
|
||||
{
|
||||
/* Check for free space */
|
||||
itemsz = MAXALIGN(IndexTupleSize(itup));
|
||||
if (PageGetFreeSpace(page) < itemsz)
|
||||
IvfflatAppendPage(index, &buf, &page, &state, forkNum);
|
||||
|
||||
/* Add the item */
|
||||
if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
|
||||
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
|
||||
|
||||
buildstate->indtuples += 1;
|
||||
|
||||
GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list);
|
||||
}
|
||||
|
||||
insertPage = BufferGetBlockNumber(buf);
|
||||
|
||||
IvfflatCommitBuffer(buf, state);
|
||||
|
||||
/* Set the start and insert pages */
|
||||
IvfflatUpdateList(index, state, buildstate->listInfo[i], insertPage, startPage, forkNum);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Initialize the build state
|
||||
*/
|
||||
static void
|
||||
InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo)
|
||||
{
|
||||
buildstate->heap = heap;
|
||||
buildstate->index = index;
|
||||
buildstate->indexInfo = indexInfo;
|
||||
|
||||
buildstate->lists = IvfflatGetLists(index);
|
||||
buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;
|
||||
|
||||
/* Require column to have dimensions to be indexed */
|
||||
if (buildstate->dimensions < 0)
|
||||
elog(ERROR, "column does not have dimensions");
|
||||
|
||||
buildstate->reltuples = 0;
|
||||
buildstate->indtuples = 0;
|
||||
|
||||
/* Get support functions */
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Create tuple description for sorting */
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
buildstate->tupdesc = CreateTemplateTupleDesc(4);
|
||||
#else
|
||||
buildstate->tupdesc = CreateTemplateTupleDesc(4, false);
|
||||
#endif
|
||||
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0);
|
||||
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
|
||||
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0);
|
||||
#else
|
||||
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0);
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual);
|
||||
#else
|
||||
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
|
||||
#endif
|
||||
|
||||
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions);
|
||||
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
|
||||
|
||||
/* Reuse for each tuple */
|
||||
buildstate->normvec = InitVector(buildstate->dimensions);
|
||||
}
|
||||
|
||||
/*
|
||||
* Free resources
|
||||
*/
|
||||
static void
|
||||
FreeBuildState(IvfflatBuildState * buildstate)
|
||||
{
|
||||
pfree(buildstate->centers);
|
||||
pfree(buildstate->listInfo);
|
||||
pfree(buildstate->normvec);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compute centers
|
||||
*/
|
||||
static void
|
||||
ComputeCenters(IvfflatBuildState * buildstate)
|
||||
{
|
||||
int numSamples;
|
||||
|
||||
/* Target 50 samples per list, with at least 10000 samples */
|
||||
/* The number of samples has a large effect on index build time */
|
||||
numSamples = buildstate->lists * 50;
|
||||
if (numSamples < 10000)
|
||||
numSamples = 10000;
|
||||
|
||||
/* Sample samples */
|
||||
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
|
||||
if (buildstate->heap != NULL)
|
||||
SampleRows(buildstate);
|
||||
|
||||
/* Calculate centers */
|
||||
IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers);
|
||||
|
||||
/* Free samples before we allocate more memory */
|
||||
pfree(buildstate->samples);
|
||||
}
|
||||
|
||||
/*
|
||||
* Create the metapage
|
||||
*/
|
||||
static void
|
||||
CreateMetaPage(Relation index, int dimensions, int lists, ForkNumber forkNum)
|
||||
{
|
||||
Buffer buf;
|
||||
Page page;
|
||||
GenericXLogState *state;
|
||||
IvfflatMetaPage metap;
|
||||
|
||||
buf = IvfflatNewBuffer(index, forkNum);
|
||||
IvfflatInitPage(index, &buf, &page, &state);
|
||||
|
||||
/* Set metapage data */
|
||||
metap = IvfflatPageGetMeta(page);
|
||||
metap->magicNumber = IVFFLAT_MAGIC_NUMBER;
|
||||
metap->version = IVFFLAT_VERSION;
|
||||
metap->dimensions = dimensions;
|
||||
metap->lists = lists;
|
||||
((PageHeader) page)->pd_lower =
|
||||
((char *) metap + sizeof(IvfflatMetaPageData)) - (char *) page;
|
||||
|
||||
IvfflatCommitBuffer(buf, state);
|
||||
}
|
||||
|
||||
/*
|
||||
* Create list pages
|
||||
*/
|
||||
static void
|
||||
CreateListPages(Relation index, VectorArray centers, int dimensions,
|
||||
int lists, ForkNumber forkNum, ListInfo * *listInfo)
|
||||
{
|
||||
int i;
|
||||
Buffer buf;
|
||||
Page page;
|
||||
GenericXLogState *state;
|
||||
OffsetNumber offno;
|
||||
Size itemsz;
|
||||
IvfflatList list;
|
||||
|
||||
itemsz = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions));
|
||||
list = palloc(itemsz);
|
||||
|
||||
buf = IvfflatNewBuffer(index, forkNum);
|
||||
IvfflatInitPage(index, &buf, &page, &state);
|
||||
|
||||
for (i = 0; i < lists; i++)
|
||||
{
|
||||
/* Load list */
|
||||
list->startPage = InvalidBlockNumber;
|
||||
list->insertPage = InvalidBlockNumber;
|
||||
memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions));
|
||||
|
||||
/* Ensure free space */
|
||||
if (PageGetFreeSpace(page) < itemsz)
|
||||
IvfflatAppendPage(index, &buf, &page, &state, forkNum);
|
||||
|
||||
/* Add the item */
|
||||
offno = PageAddItem(page, (Item) list, itemsz, InvalidOffsetNumber, false, false);
|
||||
if (offno == InvalidOffsetNumber)
|
||||
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index));
|
||||
|
||||
/* Save location info */
|
||||
(*listInfo)[i].blkno = BufferGetBlockNumber(buf);
|
||||
(*listInfo)[i].offno = offno;
|
||||
}
|
||||
|
||||
IvfflatCommitBuffer(buf, state);
|
||||
|
||||
pfree(list);
|
||||
}
|
||||
|
||||
/*
|
||||
* Create entry pages
|
||||
*/
|
||||
static void
|
||||
CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum)
|
||||
{
|
||||
AttrNumber attNums[] = {1};
|
||||
Oid sortOperators[] = {Float8LessOperator};
|
||||
Oid sortCollations[] = {InvalidOid};
|
||||
bool nullsFirstFlags[] = {false};
|
||||
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, NULL, false);
|
||||
#else
|
||||
buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, false);
|
||||
#endif
|
||||
|
||||
/* Add tuples to sort */
|
||||
if (buildstate->heap != NULL)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, true, BuildCallback, (void *) buildstate, NULL);
|
||||
#elif PG_VERSION_NUM >= 110000
|
||||
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, BuildCallback, (void *) buildstate, NULL);
|
||||
#else
|
||||
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
|
||||
true, BuildCallback, (void *) buildstate);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Sort and insert */
|
||||
tuplesort_performsort(buildstate->sortstate);
|
||||
InsertTuples(buildstate->index, buildstate, forkNum);
|
||||
tuplesort_end(buildstate->sortstate);
|
||||
}
|
||||
|
||||
/*
|
||||
* Build the index
|
||||
*/
|
||||
static void
|
||||
BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo,
|
||||
IvfflatBuildState * buildstate, ForkNumber forkNum)
|
||||
{
|
||||
InitBuildState(buildstate, heap, index, indexInfo);
|
||||
|
||||
ComputeCenters(buildstate);
|
||||
|
||||
/* Create pages */
|
||||
CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum);
|
||||
CreateListPages(index, buildstate->centers, buildstate->dimensions, buildstate->lists, forkNum, &buildstate->listInfo);
|
||||
CreateEntryPages(buildstate, forkNum);
|
||||
|
||||
FreeBuildState(buildstate);
|
||||
}
|
||||
|
||||
/*
|
||||
* Build the index for a logged table
|
||||
*/
|
||||
IndexBuildResult *
|
||||
ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo)
|
||||
{
|
||||
IndexBuildResult *result;
|
||||
IvfflatBuildState buildstate;
|
||||
|
||||
BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM);
|
||||
|
||||
result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult));
|
||||
result->heap_tuples = buildstate.reltuples;
|
||||
result->index_tuples = buildstate.indtuples;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Build the index for an unlogged table
|
||||
*/
|
||||
void
|
||||
ivfflatbuildempty(Relation index)
|
||||
{
|
||||
IndexInfo *indexInfo = BuildIndexInfo(index);
|
||||
IvfflatBuildState buildstate;
|
||||
|
||||
BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM);
|
||||
}
|
||||
168
src/ivfflat.c
Normal file
168
src/ivfflat.c
Normal file
@@ -0,0 +1,168 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/amapi.h"
|
||||
#include "commands/vacuum.h"
|
||||
#include "ivfflat.h"
|
||||
#include "utils/guc.h"
|
||||
#include "utils/selfuncs.h"
|
||||
|
||||
static relopt_kind ivfflat_relopt_kind;
|
||||
|
||||
/*
|
||||
* Initialize index options and variables
|
||||
*/
|
||||
void
|
||||
_PG_init(void)
|
||||
{
|
||||
ivfflat_relopt_kind = add_reloption_kind();
|
||||
add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists",
|
||||
IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
,AccessExclusiveLock
|
||||
#endif
|
||||
);
|
||||
|
||||
DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes",
|
||||
"Valid range is 1..lists.", &ivfflat_probes,
|
||||
1, 1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
/*
|
||||
* Estimate the cost of an index scan
|
||||
*
|
||||
* TODO Improve cost estimation
|
||||
*/
|
||||
static void
|
||||
ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
|
||||
Cost *indexStartupCost, Cost *indexTotalCost,
|
||||
Selectivity *indexSelectivity, double *indexCorrelation
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
,double *indexPages
|
||||
#endif
|
||||
)
|
||||
{
|
||||
GenericCosts costs;
|
||||
|
||||
#if PG_VERSION_NUM < 120000
|
||||
List *qinfos;
|
||||
|
||||
qinfos = deconstruct_indexquals(path);
|
||||
#endif
|
||||
|
||||
MemSet(&costs, 0, sizeof(costs));
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
genericcostestimate(root, path, loop_count, &costs);
|
||||
#else
|
||||
genericcostestimate(root, path, loop_count, qinfos, &costs);
|
||||
#endif
|
||||
|
||||
*indexStartupCost = costs.indexStartupCost;
|
||||
*indexTotalCost = costs.indexTotalCost;
|
||||
*indexSelectivity = costs.indexSelectivity;
|
||||
*indexCorrelation = costs.indexCorrelation;
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
*indexPages = costs.numIndexPages;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Parse and validate the reloptions
|
||||
*/
|
||||
static bytea *
|
||||
ivfflatoptions(Datum reloptions, bool validate)
|
||||
{
|
||||
static const relopt_parse_elt tab[] = {
|
||||
{"lists", RELOPT_TYPE_INT, offsetof(IvfflatOptions, lists)},
|
||||
};
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
return (bytea *) build_reloptions(reloptions, validate,
|
||||
ivfflat_relopt_kind,
|
||||
sizeof(IvfflatOptions),
|
||||
tab, lengthof(tab));
|
||||
#else
|
||||
relopt_value *options;
|
||||
int numoptions;
|
||||
IvfflatOptions *rdopts;
|
||||
|
||||
options = parseRelOptions(reloptions, validate, ivfflat_relopt_kind, &numoptions);
|
||||
rdopts = allocateReloptStruct(sizeof(IvfflatOptions), options, numoptions);
|
||||
fillRelOptions((void *) rdopts, sizeof(IvfflatOptions), options, numoptions,
|
||||
validate, tab, lengthof(tab));
|
||||
|
||||
return (bytea *) rdopts;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Validate catalog entries for the specified operator class
|
||||
*/
|
||||
static bool
|
||||
ivfflatvalidate(Oid opclassoid)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
PG_FUNCTION_INFO_V1(ivfflathandler);
|
||||
Datum
|
||||
ivfflathandler(PG_FUNCTION_ARGS)
|
||||
{
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 4;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
amroutine->amcanorder = false;
|
||||
amroutine->amcanorderbyop = true;
|
||||
amroutine->amcanbackward = false; /* can change direction mid-scan */
|
||||
amroutine->amcanunique = false;
|
||||
amroutine->amcanmulticol = false;
|
||||
amroutine->amoptionalkey = true;
|
||||
amroutine->amsearcharray = false;
|
||||
amroutine->amsearchnulls = false;
|
||||
amroutine->amstorage = false;
|
||||
amroutine->amclusterable = false;
|
||||
amroutine->ampredlocks = false;
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
amroutine->amcanparallel = false;
|
||||
#endif
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
amroutine->amcaninclude = false;
|
||||
#endif
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */
|
||||
amroutine->amparallelvacuumoptions = VACUUM_OPTION_NO_PARALLEL; /* TODO support parallel */
|
||||
#endif
|
||||
amroutine->amkeytype = InvalidOid;
|
||||
|
||||
amroutine->ambuild = ivfflatbuild;
|
||||
amroutine->ambuildempty = ivfflatbuildempty;
|
||||
amroutine->aminsert = ivfflatinsert;
|
||||
amroutine->ambulkdelete = ivfflatbulkdelete;
|
||||
amroutine->amvacuumcleanup = ivfflatvacuumcleanup;
|
||||
amroutine->amcanreturn = NULL;
|
||||
amroutine->amcostestimate = ivfflatcostestimate;
|
||||
amroutine->amoptions = ivfflatoptions;
|
||||
amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
amroutine->ambuildphasename = NULL;
|
||||
#endif
|
||||
amroutine->amvalidate = ivfflatvalidate;
|
||||
amroutine->ambeginscan = ivfflatbeginscan;
|
||||
amroutine->amrescan = ivfflatrescan;
|
||||
amroutine->amgettuple = ivfflatgettuple;
|
||||
amroutine->amgetbitmap = NULL;
|
||||
amroutine->amendscan = ivfflatendscan;
|
||||
amroutine->ammarkpos = NULL;
|
||||
amroutine->amrestrpos = NULL;
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
amroutine->amestimateparallelscan = NULL;
|
||||
amroutine->aminitparallelscan = NULL;
|
||||
amroutine->amparallelrescan = NULL;
|
||||
#endif
|
||||
|
||||
PG_RETURN_POINTER(amroutine);
|
||||
}
|
||||
193
src/ivfflat.h
Normal file
193
src/ivfflat.h
Normal file
@@ -0,0 +1,193 @@
|
||||
#ifndef IVFFLAT_H
|
||||
#define IVFFLAT_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/generic_xlog.h"
|
||||
#include "access/reloptions.h"
|
||||
#include "nodes/execnodes.h"
|
||||
#include "utils/sampling.h"
|
||||
#include "utils/tuplesort.h"
|
||||
#include "vector.h"
|
||||
|
||||
/* Support functions */
|
||||
#define IVFFLAT_DISTANCE_PROC 1
|
||||
#define IVFFLAT_NORM_PROC 2
|
||||
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
|
||||
#define IVFFLAT_KMEANS_NORM_PROC 4
|
||||
|
||||
#define IVFFLAT_VERSION 1
|
||||
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
|
||||
#define IVFFLAT_PAGE_ID 0xFF84
|
||||
|
||||
/* Preserved page numbers */
|
||||
#define IVFFLAT_METAPAGE_BLKNO 0
|
||||
#define IVFFLAT_HEAD_BLKNO 1 /* first list page */
|
||||
|
||||
#define IVFFLAT_DEFAULT_LISTS 100
|
||||
#define IVFFLAT_MAX_LISTS 32768
|
||||
|
||||
#define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim))
|
||||
|
||||
#define IvfflatPageGetOpaque(page) ((IvfflatPageOpaque) PageGetSpecialPointer(page))
|
||||
#define IvfflatPageGetMeta(page) ((IvfflatMetaPageData *) PageGetContents(page))
|
||||
|
||||
#if PG_VERSION_NUM < 100000
|
||||
#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber
|
||||
#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber
|
||||
#endif
|
||||
|
||||
/* Variables */
|
||||
int ivfflat_probes;
|
||||
|
||||
typedef struct VectorArrayData
|
||||
{
|
||||
int length;
|
||||
int maxlen;
|
||||
int dim;
|
||||
Vector items[FLEXIBLE_ARRAY_MEMBER];
|
||||
} VectorArrayData;
|
||||
|
||||
typedef VectorArrayData * VectorArray;
|
||||
|
||||
typedef struct ListInfo
|
||||
{
|
||||
BlockNumber blkno;
|
||||
OffsetNumber offno;
|
||||
} ListInfo;
|
||||
|
||||
/* IVFFlat index options */
|
||||
typedef struct IvfflatOptions
|
||||
{
|
||||
int32 vl_len_; /* varlena header (do not touch directly!) */
|
||||
int lists; /* number of lists */
|
||||
} IvfflatOptions;
|
||||
|
||||
typedef struct IvfflatBuildState
|
||||
{
|
||||
/* Info */
|
||||
Relation heap;
|
||||
Relation index;
|
||||
IndexInfo *indexInfo;
|
||||
|
||||
/* Settings */
|
||||
int dimensions;
|
||||
int lists;
|
||||
|
||||
/* Statistics */
|
||||
double indtuples;
|
||||
double reltuples;
|
||||
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
VectorArray samples;
|
||||
VectorArray centers;
|
||||
ListInfo *listInfo;
|
||||
Vector *normvec;
|
||||
|
||||
/* Sampling */
|
||||
BlockSamplerData bs;
|
||||
ReservoirStateData rstate;
|
||||
int rowstoskip;
|
||||
|
||||
/* Sorting */
|
||||
Tuplesortstate *sortstate;
|
||||
TupleDesc tupdesc;
|
||||
TupleTableSlot *slot;
|
||||
} IvfflatBuildState;
|
||||
|
||||
typedef struct IvfflatMetaPageData
|
||||
{
|
||||
uint32 magicNumber;
|
||||
uint32 version;
|
||||
uint16 dimensions;
|
||||
uint16 lists;
|
||||
} IvfflatMetaPageData;
|
||||
|
||||
typedef IvfflatMetaPageData * IvfflatMetaPage;
|
||||
|
||||
typedef struct IvfflatPageOpaqueData
|
||||
{
|
||||
BlockNumber nextblkno;
|
||||
uint16 unused;
|
||||
uint16 page_id; /* for identification of IVFFlat indexes */
|
||||
} IvfflatPageOpaqueData;
|
||||
|
||||
typedef IvfflatPageOpaqueData * IvfflatPageOpaque;
|
||||
|
||||
typedef struct IvfflatListData
|
||||
{
|
||||
BlockNumber startPage;
|
||||
BlockNumber insertPage;
|
||||
Vector center;
|
||||
} IvfflatListData;
|
||||
|
||||
typedef IvfflatListData * IvfflatList;
|
||||
|
||||
typedef struct IvfflatScanList
|
||||
{
|
||||
BlockNumber startPage;
|
||||
double distance;
|
||||
} IvfflatScanList;
|
||||
|
||||
typedef struct IvfflatScanOpaqueData
|
||||
{
|
||||
int probes;
|
||||
bool first;
|
||||
Buffer buf;
|
||||
|
||||
/* Sorting */
|
||||
Tuplesortstate *sortstate;
|
||||
TupleDesc tupdesc;
|
||||
TupleTableSlot *slot;
|
||||
bool isnull;
|
||||
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation;
|
||||
|
||||
IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
|
||||
} IvfflatScanOpaqueData;
|
||||
|
||||
typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
|
||||
|
||||
#define VECTOR_ARRAY_SIZE(_length, _dim) (offsetof(VectorArrayData, items) + _length * VECTOR_SIZE(_dim))
|
||||
#define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char*) _arr + offsetof(VectorArrayData, items) + (_offset) * VECTOR_SIZE(_arr->dim))
|
||||
#define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset))
|
||||
#define VectorArraySet(_arr, _offset, _val) (memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), _val, VECTOR_SIZE(_arr->dim)))
|
||||
|
||||
/* Methods */
|
||||
void _PG_init(void);
|
||||
VectorArray VectorArrayInit(int maxlen, int dimensions);
|
||||
void PrintVectorArray(char *msg, VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
|
||||
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
int IvfflatGetLists(Relation index);
|
||||
void IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum);
|
||||
void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);
|
||||
void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum);
|
||||
Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state);
|
||||
|
||||
/* Index access methods */
|
||||
IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo);
|
||||
void ivfflatbuildempty(Relation index);
|
||||
bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
,IndexInfo *indexInfo
|
||||
#endif
|
||||
);
|
||||
IndexBulkDeleteResult *ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state);
|
||||
IndexBulkDeleteResult *ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats);
|
||||
IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys);
|
||||
void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys);
|
||||
bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir);
|
||||
void ivfflatendscan(IndexScanDesc scan);
|
||||
|
||||
#endif
|
||||
163
src/ivfinsert.c
Normal file
163
src/ivfinsert.c
Normal file
@@ -0,0 +1,163 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
/*
|
||||
* Find the list that minimizes the distance function
|
||||
*/
|
||||
static void
|
||||
FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * listInfo)
|
||||
{
|
||||
Buffer cbuf;
|
||||
Page cpage;
|
||||
IvfflatList list;
|
||||
double distance;
|
||||
double minDistance = DBL_MAX;
|
||||
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
|
||||
FmgrInfo *procinfo;
|
||||
Oid collation;
|
||||
OffsetNumber offno;
|
||||
OffsetNumber maxoffno;
|
||||
|
||||
procinfo = index_getprocinfo(rel, 1, IVFFLAT_DISTANCE_PROC);
|
||||
collation = rel->rd_indcollation[0];
|
||||
|
||||
/* Search all list pages */
|
||||
while (BlockNumberIsValid(nextblkno))
|
||||
{
|
||||
cbuf = ReadBuffer(rel, nextblkno);
|
||||
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
|
||||
cpage = BufferGetPage(cbuf);
|
||||
maxoffno = PageGetMaxOffsetNumber(cpage);
|
||||
|
||||
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
|
||||
{
|
||||
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center)));
|
||||
|
||||
if (distance < minDistance)
|
||||
{
|
||||
*insertPage = list->insertPage;
|
||||
listInfo->blkno = nextblkno;
|
||||
listInfo->offno = offno;
|
||||
minDistance = distance;
|
||||
}
|
||||
}
|
||||
|
||||
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
|
||||
|
||||
UnlockReleaseBuffer(cbuf);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Prepare to insert an index tuple
|
||||
*/
|
||||
static void
|
||||
LoadInsertPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, BlockNumber insertPage)
|
||||
{
|
||||
*buf = ReadBuffer(index, insertPage);
|
||||
LockBuffer(*buf, BUFFER_LOCK_EXCLUSIVE);
|
||||
*state = GenericXLogStart(index);
|
||||
*page = GenericXLogRegisterBuffer(*state, *buf, 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Insert a tuple into the index
|
||||
*/
|
||||
static void
|
||||
InsertTuple(Relation rel, IndexTuple itup, Relation heapRel, Datum *values)
|
||||
{
|
||||
Buffer buf;
|
||||
Page page;
|
||||
GenericXLogState *state;
|
||||
Size itemsz;
|
||||
BlockNumber insertPage = InvalidBlockNumber;
|
||||
ListInfo listInfo;
|
||||
bool newPage = false;
|
||||
|
||||
/* Find the insert page - sets the page and list info */
|
||||
FindInsertPage(rel, values, &insertPage, &listInfo);
|
||||
Assert(BlockNumberIsValid(insertPage));
|
||||
|
||||
itemsz = MAXALIGN(IndexTupleSize(itup));
|
||||
Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)));
|
||||
|
||||
LoadInsertPage(rel, &buf, &page, &state, insertPage);
|
||||
|
||||
/* Find a page to insert the item */
|
||||
while (PageGetFreeSpace(page) < itemsz)
|
||||
{
|
||||
insertPage = IvfflatPageGetOpaque(page)->nextblkno;
|
||||
|
||||
if (BlockNumberIsValid(insertPage))
|
||||
{
|
||||
/* Move to next page */
|
||||
GenericXLogAbort(state);
|
||||
UnlockReleaseBuffer(buf);
|
||||
|
||||
LoadInsertPage(rel, &buf, &page, &state, insertPage);
|
||||
}
|
||||
else
|
||||
{
|
||||
/* Add a new page */
|
||||
IvfflatAppendPage(rel, &buf, &page, &state, MAIN_FORKNUM);
|
||||
|
||||
insertPage = BufferGetBlockNumber(buf);
|
||||
newPage = true;
|
||||
}
|
||||
}
|
||||
|
||||
/* Add to next offset */
|
||||
if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber)
|
||||
elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(rel));
|
||||
|
||||
IvfflatCommitBuffer(buf, state);
|
||||
|
||||
/* Update the insert page */
|
||||
if (newPage)
|
||||
IvfflatUpdateList(rel, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
|
||||
}
|
||||
|
||||
/*
|
||||
* Insert a tuple into the index
|
||||
*/
|
||||
bool
|
||||
ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid,
|
||||
Relation heap, IndexUniqueCheck checkUnique
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
,IndexInfo *indexInfo
|
||||
#endif
|
||||
)
|
||||
{
|
||||
IndexTuple itup;
|
||||
Datum value;
|
||||
FmgrInfo *normprocinfo;
|
||||
|
||||
if (isnull[0])
|
||||
return false;
|
||||
|
||||
value = values[0];
|
||||
|
||||
/* Normalize if needed */
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL))
|
||||
return false;
|
||||
}
|
||||
|
||||
itup = index_form_tuple(RelationGetDescr(index), &value, isnull);
|
||||
itup->t_tid = *heap_tid;
|
||||
InsertTuple(index, itup, heap, &value);
|
||||
pfree(itup);
|
||||
|
||||
/* Clean up if we allocated a new value */
|
||||
if (value != values[0])
|
||||
pfree(DatumGetPointer(value));
|
||||
|
||||
return false;
|
||||
}
|
||||
479
src/ivfkmeans.c
Normal file
479
src/ivfkmeans.c
Normal file
@@ -0,0 +1,479 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
|
||||
/*
|
||||
* Initialize with kmeans++
|
||||
*
|
||||
* https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
|
||||
*/
|
||||
static void
|
||||
InitCenters(Relation index, VectorArray samples, VectorArray centers, double *lowerBound)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
Oid collation;
|
||||
int i;
|
||||
int j;
|
||||
double distance;
|
||||
double sum;
|
||||
double choice;
|
||||
Vector *vec;
|
||||
double *weight = palloc(samples->length * sizeof(double));
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Choose an initial center uniformly at random */
|
||||
VectorArraySet(centers, 0, VectorArrayGet(samples, random() % samples->length));
|
||||
centers->length++;
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
weight[j] = DBL_MAX;
|
||||
|
||||
for (i = 0; i < numCenters; i++)
|
||||
{
|
||||
CHECK_FOR_INTERRUPTS();
|
||||
|
||||
sum = 0.0;
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
vec = VectorArrayGet(samples, j);
|
||||
|
||||
/* Only need to compute distance for new center */
|
||||
/* TODO Use triangle inequality to reduce distance calculations */
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
|
||||
/* Set lower bound */
|
||||
lowerBound[j * numCenters + i] = distance;
|
||||
|
||||
/* Use distance squared for weighted probability distribution */
|
||||
distance *= distance;
|
||||
|
||||
if (distance < weight[j])
|
||||
weight[j] = distance;
|
||||
|
||||
sum += weight[j];
|
||||
}
|
||||
|
||||
/* Only compute lower bound on last iteration */
|
||||
if (i + 1 == numCenters)
|
||||
break;
|
||||
|
||||
/* Choose new center using weighted probability distribution. */
|
||||
choice = sum * (((double) random()) / MAX_RANDOM_VALUE);
|
||||
for (j = 0; j < numSamples - 1; j++)
|
||||
{
|
||||
choice -= weight[j];
|
||||
if (choice <= 0)
|
||||
break;
|
||||
}
|
||||
|
||||
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
|
||||
centers->length++;
|
||||
}
|
||||
|
||||
pfree(weight);
|
||||
}
|
||||
|
||||
/*
|
||||
* Apply norm to vector
|
||||
*/
|
||||
static inline void
|
||||
ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec)
|
||||
{
|
||||
int i;
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec)));
|
||||
|
||||
/* TODO Handle zero norm */
|
||||
if (norm > 0)
|
||||
{
|
||||
for (i = 0; i < vec->dim; i++)
|
||||
vec->x[i] /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare vectors
|
||||
*/
|
||||
static int
|
||||
CompareVectors(const void *a, const void *b)
|
||||
{
|
||||
return vector_cmp_internal((Vector *) a, (Vector *) b);
|
||||
}
|
||||
|
||||
/*
|
||||
* Quick approach if we have little data
|
||||
*/
|
||||
static void
|
||||
QuickCenters(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
int i;
|
||||
int j;
|
||||
Vector *vec;
|
||||
int dimensions = centers->dim;
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
|
||||
/* Copy existing vectors while avoiding duplicates */
|
||||
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);
|
||||
for (i = 0; i < samples->length; i++)
|
||||
{
|
||||
vec = VectorArrayGet(samples, i);
|
||||
|
||||
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0)
|
||||
{
|
||||
VectorArraySet(centers, centers->length, vec);
|
||||
centers->length++;
|
||||
}
|
||||
}
|
||||
|
||||
/* Fill remaining with random data */
|
||||
while (centers->length < centers->maxlen)
|
||||
{
|
||||
vec = VectorArrayGet(centers, centers->length);
|
||||
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
|
||||
for (j = 0; j < dimensions; j++)
|
||||
vec->x[j] = ((double) random()) / MAX_RANDOM_VALUE;
|
||||
|
||||
/* Normalize if needed (only needed for random centers) */
|
||||
if (normprocinfo != NULL)
|
||||
ApplyNorm(normprocinfo, collation, vec);
|
||||
|
||||
centers->length++;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Use Elkan for performance. This requires distance function to satisfy triangle inequality.
|
||||
*
|
||||
* We use L2 distance for L2 (not L2 squared like index scan)
|
||||
* and angular distance for inner product and cosine distance
|
||||
*
|
||||
* https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf
|
||||
*/
|
||||
static void
|
||||
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation;
|
||||
Vector *vec;
|
||||
Vector *newCenter;
|
||||
int iteration;
|
||||
int j;
|
||||
int k;
|
||||
int dimensions = centers->dim;
|
||||
int numCenters = centers->maxlen;
|
||||
int numSamples = samples->length;
|
||||
VectorArray newCenters;
|
||||
int *centerCounts;
|
||||
int *closestCenters;
|
||||
double *lowerBound;
|
||||
double *upperBound;
|
||||
double *s;
|
||||
double *halfcdist;
|
||||
double *newcdist;
|
||||
int changes;
|
||||
double minDistance;
|
||||
int closestCenter;
|
||||
double distance;
|
||||
bool rj;
|
||||
bool rjreset;
|
||||
double dxcx;
|
||||
double dxc;
|
||||
|
||||
/* Set support functions */
|
||||
procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC);
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
/* Allocate space */
|
||||
centerCounts = palloc(sizeof(int) * numCenters);
|
||||
closestCenters = palloc(sizeof(int) * numSamples);
|
||||
lowerBound = palloc(sizeof(double) * numSamples * numCenters);
|
||||
upperBound = palloc(sizeof(double) * numSamples);
|
||||
s = palloc(sizeof(double) * numCenters);
|
||||
halfcdist = palloc(sizeof(double) * numCenters * numCenters);
|
||||
newcdist = palloc(sizeof(double) * numCenters);
|
||||
|
||||
newCenters = VectorArrayInit(numCenters, dimensions);
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
|
||||
vec->dim = dimensions;
|
||||
}
|
||||
|
||||
/* Pick initial centers */
|
||||
InitCenters(index, samples, centers, lowerBound);
|
||||
|
||||
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
minDistance = DBL_MAX;
|
||||
closestCenter = -1;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
|
||||
/* Find closest center */
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
/* TODO Use Lemma 1 in k-means++ initialization */
|
||||
distance = lowerBound[j * numCenters + k];
|
||||
|
||||
if (distance < minDistance)
|
||||
{
|
||||
minDistance = distance;
|
||||
closestCenter = k;
|
||||
}
|
||||
}
|
||||
|
||||
upperBound[j] = minDistance;
|
||||
closestCenters[j] = closestCenter;
|
||||
}
|
||||
|
||||
/* Give 500 iterations to converge */
|
||||
for (iteration = 0; iteration < 500; iteration++)
|
||||
{
|
||||
/* Can take a while, so ensure we can interrupt */
|
||||
CHECK_FOR_INTERRUPTS();
|
||||
|
||||
changes = 0;
|
||||
|
||||
/* Step 1: For all centers, compute distance */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(centers, j);
|
||||
|
||||
for (k = j + 1; k < numCenters; k++)
|
||||
{
|
||||
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
|
||||
halfcdist[j * numCenters + k] = distance;
|
||||
halfcdist[k * numCenters + j] = distance;
|
||||
}
|
||||
}
|
||||
|
||||
/* For all centers c, compute s(c) */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
minDistance = DBL_MAX;
|
||||
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
if (j == k)
|
||||
continue;
|
||||
|
||||
distance = halfcdist[j * numCenters + k];
|
||||
if (distance < minDistance)
|
||||
minDistance = distance;
|
||||
}
|
||||
|
||||
s[j] = minDistance;
|
||||
}
|
||||
|
||||
rjreset = iteration != 0;
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
/* Step 2: Identify all points x such that u(x) <= s(c(x)) */
|
||||
if (upperBound[j] <= s[closestCenters[j]])
|
||||
continue;
|
||||
|
||||
rj = rjreset;
|
||||
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
/* Step 3: For all remaining points x and centers c */
|
||||
if (k == closestCenters[j])
|
||||
continue;
|
||||
|
||||
if (upperBound[j] <= lowerBound[j * numCenters + k])
|
||||
continue;
|
||||
|
||||
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
|
||||
continue;
|
||||
|
||||
vec = VectorArrayGet(samples, j);
|
||||
|
||||
/* Step 3a */
|
||||
if (rj)
|
||||
{
|
||||
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
|
||||
|
||||
/* d(x,c(x)) computed, which is a form of d(x,c) */
|
||||
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
|
||||
upperBound[j] = dxcx;
|
||||
|
||||
rj = false;
|
||||
}
|
||||
else
|
||||
dxcx = upperBound[j];
|
||||
|
||||
/* Step 3b */
|
||||
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
|
||||
{
|
||||
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
|
||||
|
||||
/* d(x,c) calculated */
|
||||
lowerBound[j * numCenters + k] = dxc;
|
||||
|
||||
if (dxc < dxcx)
|
||||
{
|
||||
closestCenters[j] = k;
|
||||
|
||||
/* c(x) changed */
|
||||
upperBound[j] = dxc;
|
||||
|
||||
changes++;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Step 4: For each center c, let m(c) be mean of all points assigned */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] = 0.0;
|
||||
|
||||
centerCounts[j] = 0;
|
||||
}
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
vec = VectorArrayGet(samples, j);
|
||||
closestCenter = closestCenters[j];
|
||||
|
||||
/* Increment sum and count of closest center */
|
||||
newCenter = VectorArrayGet(newCenters, closestCenter);
|
||||
for (k = 0; k < dimensions; k++)
|
||||
newCenter->x[k] += vec->x[k];
|
||||
|
||||
centerCounts[closestCenter] += 1;
|
||||
}
|
||||
|
||||
for (j = 0; j < numCenters; j++)
|
||||
{
|
||||
vec = VectorArrayGet(newCenters, j);
|
||||
|
||||
if (centerCounts[j] > 0)
|
||||
{
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] /= centerCounts[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
/* TODO Handle empty centers properly */
|
||||
for (k = 0; k < dimensions; k++)
|
||||
vec->x[k] = ((double) random()) / MAX_RANDOM_VALUE;
|
||||
}
|
||||
|
||||
/* Normalize if needed */
|
||||
if (normprocinfo != NULL)
|
||||
ApplyNorm(normprocinfo, collation, vec);
|
||||
}
|
||||
|
||||
/* Step 5 */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
|
||||
|
||||
for (j = 0; j < numSamples; j++)
|
||||
{
|
||||
for (k = 0; k < numCenters; k++)
|
||||
{
|
||||
distance = lowerBound[j * numCenters + k] - newcdist[k];
|
||||
|
||||
if (distance < 0)
|
||||
distance = 0;
|
||||
|
||||
lowerBound[j * numCenters + k] = distance;
|
||||
}
|
||||
}
|
||||
|
||||
/* Step 6 */
|
||||
/* We reset r(x) before Step 3 in the next iteration */
|
||||
for (j = 0; j < numSamples; j++)
|
||||
upperBound[j] += newcdist[closestCenters[j]];
|
||||
|
||||
/* Step 7 */
|
||||
for (j = 0; j < numCenters; j++)
|
||||
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
|
||||
|
||||
if (changes == 0 && iteration != 0)
|
||||
break;
|
||||
}
|
||||
|
||||
pfree(newCenters);
|
||||
pfree(centerCounts);
|
||||
pfree(closestCenters);
|
||||
pfree(lowerBound);
|
||||
pfree(upperBound);
|
||||
pfree(s);
|
||||
pfree(halfcdist);
|
||||
pfree(newcdist);
|
||||
}
|
||||
|
||||
/*
|
||||
* Detect issues with centers
|
||||
*/
|
||||
static void
|
||||
CheckCenters(Relation index, VectorArray centers)
|
||||
{
|
||||
FmgrInfo *normprocinfo;
|
||||
Oid collation;
|
||||
int i;
|
||||
double norm;
|
||||
|
||||
if (centers->length != centers->maxlen)
|
||||
elog(ERROR, "Not enough centers. Please report a bug.");
|
||||
|
||||
/* Ensure no duplicate centers */
|
||||
/* Fine to sort in-place */
|
||||
qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);
|
||||
for (i = 1; i < centers->length; i++)
|
||||
{
|
||||
if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0)
|
||||
elog(ERROR, "Duplicate centers detected. Please report a bug.");
|
||||
}
|
||||
|
||||
/* Ensure no zero vectors for cosine distance */
|
||||
/* Check NORM_PROC instead of KMEANS_NORM_PROC */
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
collation = index->rd_indcollation[0];
|
||||
|
||||
for (i = 0; i < centers->length; i++)
|
||||
{
|
||||
norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
|
||||
if (norm == 0)
|
||||
elog(ERROR, "Zero norm detected. Please report a bug.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Perform naive k-means centering
|
||||
* We use spherical k-means for inner product and cosine
|
||||
*/
|
||||
void
|
||||
IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers)
|
||||
{
|
||||
if (samples->length <= centers->maxlen)
|
||||
QuickCenters(index, samples, centers);
|
||||
else
|
||||
ElkanKmeans(index, samples, centers);
|
||||
|
||||
CheckCenters(index, centers);
|
||||
}
|
||||
327
src/ivfscan.c
Normal file
327
src/ivfscan.c
Normal file
@@ -0,0 +1,327 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "access/relscan.h"
|
||||
#include "ivfflat.h"
|
||||
#include "miscadmin.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
#include "catalog/pg_operator_d.h"
|
||||
#include "catalog/pg_type_d.h"
|
||||
#else
|
||||
#include "catalog/pg_operator.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Compare list distances
|
||||
*/
|
||||
static int
|
||||
CompareLists(const void *a, const void *b)
|
||||
{
|
||||
double diff = (((IvfflatScanList *) a)->distance - ((IvfflatScanList *) b)->distance);
|
||||
|
||||
if (diff > 0)
|
||||
return 1;
|
||||
|
||||
if (diff < 0)
|
||||
return -1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get lists and sort by distance
|
||||
*/
|
||||
static void
|
||||
GetScanLists(IndexScanDesc scan, Datum value)
|
||||
{
|
||||
Buffer cbuf;
|
||||
Page cpage;
|
||||
IvfflatList list;
|
||||
OffsetNumber offno;
|
||||
OffsetNumber maxoffno;
|
||||
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
|
||||
int listCount = 0;
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
double distance;
|
||||
|
||||
/* Search all list pages */
|
||||
while (BlockNumberIsValid(nextblkno))
|
||||
{
|
||||
cbuf = ReadBuffer(scan->indexRelation, nextblkno);
|
||||
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
|
||||
cpage = BufferGetPage(cbuf);
|
||||
|
||||
maxoffno = PageGetMaxOffsetNumber(cpage);
|
||||
|
||||
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
|
||||
{
|
||||
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno));
|
||||
|
||||
/* Use procinfo from the index instead of scan key for performance */
|
||||
distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
|
||||
|
||||
so->lists[listCount].startPage = list->startPage;
|
||||
so->lists[listCount].distance = distance;
|
||||
listCount++;
|
||||
}
|
||||
|
||||
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
|
||||
|
||||
UnlockReleaseBuffer(cbuf);
|
||||
}
|
||||
|
||||
/* Sort by distance */
|
||||
qsort(so->lists, listCount, sizeof(IvfflatScanList), CompareLists);
|
||||
|
||||
if (so->probes > listCount)
|
||||
so->probes = listCount;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get items
|
||||
*/
|
||||
static void
|
||||
GetScanItems(IndexScanDesc scan, Datum value)
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
Buffer buf;
|
||||
Page page;
|
||||
IndexTuple itup;
|
||||
BlockNumber searchPage;
|
||||
OffsetNumber offno;
|
||||
OffsetNumber maxoffno;
|
||||
Datum datum;
|
||||
bool isnull;
|
||||
int i;
|
||||
TupleDesc tupdesc = RelationGetDescr(scan->indexRelation);
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsVirtual);
|
||||
#else
|
||||
TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Reuse same set of shared buffers for scan
|
||||
*
|
||||
* See postgres/src/backend/storage/buffer/README for description
|
||||
*/
|
||||
BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
|
||||
|
||||
/* Search closest probes lists */
|
||||
for (i = 0; i < so->probes; i++)
|
||||
{
|
||||
searchPage = so->lists[i].startPage;
|
||||
|
||||
/* Search all entry pages for list */
|
||||
while (BlockNumberIsValid(searchPage))
|
||||
{
|
||||
buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
|
||||
LockBuffer(buf, BUFFER_LOCK_SHARE);
|
||||
page = BufferGetPage(buf);
|
||||
maxoffno = PageGetMaxOffsetNumber(page);
|
||||
|
||||
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
|
||||
{
|
||||
itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
|
||||
datum = index_getattr(itup, 1, tupdesc, &isnull);
|
||||
|
||||
/*
|
||||
* Add virtual tuple
|
||||
*
|
||||
* Use procinfo from the index instead of scan key for
|
||||
* performance
|
||||
*/
|
||||
ExecClearTuple(slot);
|
||||
slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value);
|
||||
slot->tts_isnull[0] = false;
|
||||
slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid));
|
||||
slot->tts_isnull[1] = false;
|
||||
slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid));
|
||||
slot->tts_isnull[2] = false;
|
||||
slot->tts_values[3] = Int32GetDatum((int) searchPage);
|
||||
slot->tts_isnull[3] = false;
|
||||
ExecStoreVirtualTuple(slot);
|
||||
|
||||
tuplesort_puttupleslot(so->sortstate, slot);
|
||||
}
|
||||
|
||||
searchPage = IvfflatPageGetOpaque(page)->nextblkno;
|
||||
|
||||
UnlockReleaseBuffer(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Prepare for an index scan
|
||||
*/
|
||||
IndexScanDesc
|
||||
ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
{
|
||||
IndexScanDesc scan;
|
||||
IvfflatScanOpaque so;
|
||||
int lists;
|
||||
AttrNumber attNums[] = {1};
|
||||
Oid sortOperators[] = {Float8LessOperator};
|
||||
Oid sortCollations[] = {InvalidOid};
|
||||
bool nullsFirstFlags[] = {false};
|
||||
|
||||
scan = RelationGetIndexScan(index, nkeys, norderbys);
|
||||
lists = IvfflatGetLists(scan->indexRelation);
|
||||
|
||||
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList));
|
||||
so->buf = InvalidBuffer;
|
||||
so->first = true;
|
||||
|
||||
/* Set support functions */
|
||||
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
so->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Create tuple description for sorting */
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
so->tupdesc = CreateTemplateTupleDesc(4);
|
||||
#else
|
||||
so->tupdesc = CreateTemplateTupleDesc(4, false);
|
||||
#endif
|
||||
TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0);
|
||||
TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
|
||||
TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
|
||||
TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0);
|
||||
|
||||
/* Prep sort */
|
||||
#if PG_VERSION_NUM >= 110000
|
||||
so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false);
|
||||
#else
|
||||
so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, false);
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple);
|
||||
#else
|
||||
so->slot = MakeSingleTupleTableSlot(so->tupdesc);
|
||||
#endif
|
||||
|
||||
scan->opaque = so;
|
||||
|
||||
return scan;
|
||||
}
|
||||
|
||||
/*
|
||||
* Start or restart an index scan
|
||||
*/
|
||||
void
|
||||
ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys)
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
if (!so->first)
|
||||
tuplesort_reset(so->sortstate);
|
||||
#endif
|
||||
|
||||
so->first = true;
|
||||
so->probes = ivfflat_probes;
|
||||
|
||||
if (keys && scan->numberOfKeys > 0)
|
||||
memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
|
||||
|
||||
if (orderbys && scan->numberOfOrderBys > 0)
|
||||
memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData));
|
||||
}
|
||||
|
||||
/*
|
||||
* Fetch the next tuple in the given scan
|
||||
*/
|
||||
bool
|
||||
ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
|
||||
/*
|
||||
* Index can be used to scan backward, but Postgres doesn't support
|
||||
* backward scan on operators
|
||||
*/
|
||||
Assert(ScanDirectionIsForward(dir));
|
||||
|
||||
if (so->first)
|
||||
{
|
||||
Datum value;
|
||||
|
||||
/* No items will match if null */
|
||||
if (scan->orderByData->sk_flags & SK_ISNULL)
|
||||
return false;
|
||||
|
||||
value = scan->orderByData->sk_argument;
|
||||
|
||||
if (so->normprocinfo != NULL)
|
||||
{
|
||||
/* No items will match if normalization fails */
|
||||
if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL))
|
||||
return false;
|
||||
}
|
||||
|
||||
GetScanLists(scan, value);
|
||||
GetScanItems(scan, value);
|
||||
tuplesort_performsort(so->sortstate);
|
||||
so->first = false;
|
||||
|
||||
/* Clean up if we allocated a new value */
|
||||
if (value != scan->orderByData->sk_argument)
|
||||
pfree(DatumGetPointer(value));
|
||||
}
|
||||
|
||||
#if PG_VERSION_NUM >= 100000
|
||||
if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL))
|
||||
#else
|
||||
if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL))
|
||||
#endif
|
||||
{
|
||||
BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull));
|
||||
OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));
|
||||
BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull));
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
ItemPointerSet(&scan->xs_heaptid, blkno, offset);
|
||||
#else
|
||||
ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset);
|
||||
#endif
|
||||
|
||||
if (BufferIsValid(so->buf))
|
||||
ReleaseBuffer(so->buf);
|
||||
|
||||
/*
|
||||
* An index scan must maintain a pin on the index page holding the
|
||||
* item last returned by amgettuple
|
||||
*
|
||||
* https://www.postgresql.org/docs/current/index-locking.html
|
||||
*/
|
||||
so->buf = ReadBuffer(scan->indexRelation, indexblkno);
|
||||
|
||||
scan->xs_recheckorderby = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* End a scan and release resources
|
||||
*/
|
||||
void
|
||||
ivfflatendscan(IndexScanDesc scan)
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
|
||||
/* Release pin */
|
||||
if (BufferIsValid(so->buf))
|
||||
ReleaseBuffer(so->buf);
|
||||
|
||||
tuplesort_end(so->sortstate);
|
||||
|
||||
pfree(so);
|
||||
scan->opaque = NULL;
|
||||
}
|
||||
176
src/ivfutils.c
Normal file
176
src/ivfutils.c
Normal file
@@ -0,0 +1,176 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
#include "vector.h"
|
||||
|
||||
/*
|
||||
* Allocate a vector array
|
||||
*/
|
||||
VectorArray
|
||||
VectorArrayInit(int maxlen, int dimensions)
|
||||
{
|
||||
VectorArray res = palloc0(VECTOR_ARRAY_SIZE(maxlen, dimensions));
|
||||
|
||||
res->length = 0;
|
||||
res->maxlen = maxlen;
|
||||
res->dim = dimensions;
|
||||
return res;
|
||||
}
|
||||
|
||||
/*
|
||||
* Print vector array - useful for debugging
|
||||
*/
|
||||
void
|
||||
PrintVectorArray(char *msg, VectorArray arr)
|
||||
{
|
||||
int i;
|
||||
|
||||
for (i = 0; i < arr->length; i++)
|
||||
PrintVector(msg, VectorArrayGet(arr, i));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the number of lists in the index
|
||||
*/
|
||||
int
|
||||
IvfflatGetLists(Relation index)
|
||||
{
|
||||
IvfflatOptions *opts = (IvfflatOptions *) index->rd_options;
|
||||
|
||||
if (opts)
|
||||
return opts->lists;
|
||||
|
||||
return IVFFLAT_DEFAULT_LISTS;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get proc
|
||||
*/
|
||||
FmgrInfo *
|
||||
IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
|
||||
{
|
||||
if (!OidIsValid(index_getprocid(rel, 1, procnum)))
|
||||
return NULL;
|
||||
|
||||
return index_getprocinfo(rel, 1, procnum);
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
* Returns false if value should not be indexed
|
||||
*
|
||||
* The caller needs to free the pointer stored in value
|
||||
* if it's different than the original value
|
||||
*/
|
||||
bool
|
||||
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
|
||||
{
|
||||
Vector *v;
|
||||
int i;
|
||||
double norm;
|
||||
|
||||
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
v = (Vector *) DatumGetPointer(*value);
|
||||
|
||||
if (result == NULL)
|
||||
result = InitVector(v->dim);
|
||||
|
||||
for (i = 0; i < v->dim; i++)
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
* New buffer
|
||||
*/
|
||||
Buffer
|
||||
IvfflatNewBuffer(Relation index, ForkNumber forkNum)
|
||||
{
|
||||
Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);
|
||||
|
||||
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
|
||||
return buf;
|
||||
}
|
||||
|
||||
/*
|
||||
* Init page
|
||||
*/
|
||||
void
|
||||
IvfflatInitPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
|
||||
{
|
||||
*state = GenericXLogStart(index);
|
||||
*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
|
||||
PageInit(*page, BufferGetPageSize(*buf), sizeof(IvfflatPageOpaqueData));
|
||||
IvfflatPageGetOpaque(*page)->nextblkno = InvalidBlockNumber;
|
||||
IvfflatPageGetOpaque(*page)->page_id = IVFFLAT_PAGE_ID;
|
||||
}
|
||||
|
||||
/*
|
||||
* Commit buffer
|
||||
*/
|
||||
void
|
||||
IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
|
||||
{
|
||||
MarkBufferDirty(buf);
|
||||
GenericXLogFinish(state);
|
||||
UnlockReleaseBuffer(buf);
|
||||
}
|
||||
|
||||
/*
|
||||
* Add a new page
|
||||
*
|
||||
* The order is very important!!
|
||||
*/
|
||||
void
|
||||
IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
|
||||
{
|
||||
Buffer prevbuf = *buf;
|
||||
|
||||
/* Get new buffer */
|
||||
*buf = IvfflatNewBuffer(index, forkNum);
|
||||
|
||||
/* Update and commit previous buffer */
|
||||
IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(*buf);
|
||||
IvfflatCommitBuffer(prevbuf, *state);
|
||||
|
||||
/* Init new page */
|
||||
IvfflatInitPage(index, buf, page, state);
|
||||
}
|
||||
|
||||
/*
|
||||
* Update the start or insert page of a list
|
||||
*/
|
||||
void
|
||||
IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo,
|
||||
BlockNumber insertPage, BlockNumber startPage, ForkNumber forkNum)
|
||||
{
|
||||
Buffer buf;
|
||||
Page page;
|
||||
IvfflatList list;
|
||||
|
||||
buf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL);
|
||||
LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
|
||||
state = GenericXLogStart(index);
|
||||
page = GenericXLogRegisterBuffer(state, buf, 0);
|
||||
list = (IvfflatList) PageGetItem(page, PageGetItemId(page, listInfo.offno));
|
||||
|
||||
if (BlockNumberIsValid(insertPage))
|
||||
list->insertPage = insertPage;
|
||||
|
||||
if (BlockNumberIsValid(startPage))
|
||||
list->startPage = startPage;
|
||||
|
||||
/* Could only commit if changed, but extra complexity isn't needed */
|
||||
IvfflatCommitBuffer(buf, state);
|
||||
}
|
||||
151
src/ivfvacuum.c
Normal file
151
src/ivfvacuum.c
Normal file
@@ -0,0 +1,151 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "commands/vacuum.h"
|
||||
#include "ivfflat.h"
|
||||
#include "storage/bufmgr.h"
|
||||
|
||||
/*
|
||||
* Bulk delete tuples from the index
|
||||
*/
|
||||
IndexBulkDeleteResult *
|
||||
ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
|
||||
IndexBulkDeleteCallback callback, void *callback_state)
|
||||
{
|
||||
Relation index = info->index;
|
||||
Buffer cbuf;
|
||||
Page cpage;
|
||||
Buffer buf;
|
||||
Page page;
|
||||
IvfflatList list;
|
||||
IndexTuple itup;
|
||||
ItemPointer htup;
|
||||
OffsetNumber deletable[MaxOffsetNumber];
|
||||
int ndeletable;
|
||||
OffsetNumber startPages[MaxOffsetNumber];
|
||||
BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO;
|
||||
BlockNumber searchPage;
|
||||
BlockNumber insertPage;
|
||||
GenericXLogState *state;
|
||||
OffsetNumber coffno;
|
||||
OffsetNumber cmaxoffno;
|
||||
OffsetNumber offno;
|
||||
OffsetNumber maxoffno;
|
||||
ListInfo listInfo;
|
||||
BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
|
||||
|
||||
if (stats == NULL)
|
||||
stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult));
|
||||
|
||||
/* Iterate over list pages */
|
||||
while (BlockNumberIsValid(nextblkno))
|
||||
{
|
||||
cbuf = ReadBuffer(index, nextblkno);
|
||||
LockBuffer(cbuf, BUFFER_LOCK_SHARE);
|
||||
cpage = BufferGetPage(cbuf);
|
||||
|
||||
cmaxoffno = PageGetMaxOffsetNumber(cpage);
|
||||
|
||||
/* Iterate over lists */
|
||||
for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
|
||||
{
|
||||
list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno));
|
||||
startPages[coffno - FirstOffsetNumber] = list->startPage;
|
||||
}
|
||||
|
||||
listInfo.blkno = nextblkno;
|
||||
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
|
||||
|
||||
UnlockReleaseBuffer(cbuf);
|
||||
|
||||
for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
|
||||
{
|
||||
searchPage = startPages[coffno - FirstOffsetNumber];
|
||||
insertPage = InvalidBlockNumber;
|
||||
|
||||
/* Iterate over entry pages */
|
||||
while (BlockNumberIsValid(searchPage))
|
||||
{
|
||||
vacuum_delay_point();
|
||||
|
||||
buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas);
|
||||
|
||||
/*
|
||||
* ambulkdelete cannot delete entries from pages that are
|
||||
* pinned by other backends
|
||||
*
|
||||
* https://www.postgresql.org/docs/current/index-locking.html
|
||||
*/
|
||||
LockBufferForCleanup(buf);
|
||||
|
||||
state = GenericXLogStart(index);
|
||||
page = GenericXLogRegisterBuffer(state, buf, 0);
|
||||
|
||||
maxoffno = PageGetMaxOffsetNumber(page);
|
||||
ndeletable = 0;
|
||||
|
||||
/* Find deleted tuples */
|
||||
for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno))
|
||||
{
|
||||
itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno));
|
||||
htup = &(itup->t_tid);
|
||||
|
||||
if (callback(htup, callback_state))
|
||||
{
|
||||
deletable[ndeletable++] = offno;
|
||||
stats->tuples_removed++;
|
||||
}
|
||||
else
|
||||
stats->num_index_tuples++;
|
||||
}
|
||||
|
||||
searchPage = IvfflatPageGetOpaque(page)->nextblkno;
|
||||
|
||||
if (ndeletable > 0)
|
||||
{
|
||||
/* Delete tuples */
|
||||
PageIndexMultiDelete(page, deletable, ndeletable);
|
||||
MarkBufferDirty(buf);
|
||||
GenericXLogFinish(state);
|
||||
|
||||
/* Set to first free page */
|
||||
if (!BlockNumberIsValid(insertPage))
|
||||
insertPage = searchPage;
|
||||
}
|
||||
else
|
||||
GenericXLogAbort(state);
|
||||
|
||||
UnlockReleaseBuffer(buf);
|
||||
}
|
||||
|
||||
/*
|
||||
* Update after all tuples deleted.
|
||||
*
|
||||
* We don't add or delete items from lists pages, so offset won't
|
||||
* change.
|
||||
*/
|
||||
if (!BlockNumberIsValid(insertPage))
|
||||
{
|
||||
listInfo.offno = coffno;
|
||||
IvfflatUpdateList(index, state, listInfo, insertPage, InvalidBlockNumber, MAIN_FORKNUM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats;
|
||||
}
|
||||
|
||||
/*
|
||||
* Clean up after a VACUUM operation
|
||||
*/
|
||||
IndexBulkDeleteResult *
|
||||
ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats)
|
||||
{
|
||||
Relation rel = info->index;
|
||||
|
||||
if (stats == NULL)
|
||||
return NULL;
|
||||
|
||||
stats->num_pages = RelationGetNumberOfBlocks(rel);
|
||||
|
||||
return stats;
|
||||
}
|
||||
610
src/vector.c
Normal file
610
src/vector.c
Normal file
@@ -0,0 +1,610 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "vector.h"
|
||||
#include "fmgr.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "lib/stringinfo.h"
|
||||
#include "utils/array.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "utils/lsyscache.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 120000
|
||||
#include "utils/float.h"
|
||||
#endif
|
||||
|
||||
PG_MODULE_MAGIC;
|
||||
|
||||
/*
|
||||
* Ensure same dimensions
|
||||
*/
|
||||
static inline void
|
||||
CheckDims(Vector * a, Vector * b)
|
||||
{
|
||||
if (a->dim != b->dim)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("different vector dimensions %d and %d", a->dim, b->dim)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure expected dimension
|
||||
*/
|
||||
static inline void
|
||||
CheckExpectedDim(int32 typmod, int dim)
|
||||
{
|
||||
if (typmod != -1 && typmod != dim)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("expected %d dimensions, not %d", typmod, dim)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure finite elements
|
||||
*/
|
||||
static inline void
|
||||
CheckElement(float value)
|
||||
{
|
||||
if (isnan(value))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("NaN not allowed in vector")));
|
||||
|
||||
|
||||
if (isinf(value))
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("infinite value not allowed in vector")));
|
||||
}
|
||||
|
||||
/*
|
||||
* Print vector - useful for debugging
|
||||
*/
|
||||
void
|
||||
PrintVector(char *msg, Vector * vector)
|
||||
{
|
||||
StringInfoData buf;
|
||||
int dim = vector->dim;
|
||||
int i;
|
||||
|
||||
initStringInfo(&buf);
|
||||
|
||||
appendStringInfoChar(&buf, '[');
|
||||
for (i = 0; i < dim; i++)
|
||||
{
|
||||
if (i > 0)
|
||||
appendStringInfoString(&buf, ",");
|
||||
appendStringInfoString(&buf, float8out_internal(vector->x[i]));
|
||||
}
|
||||
appendStringInfoChar(&buf, ']');
|
||||
|
||||
elog(INFO, "%s = %s", msg, buf.data);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert textual representation to internal representation
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_in);
|
||||
Datum
|
||||
vector_in(PG_FUNCTION_ARGS)
|
||||
{
|
||||
char *str = PG_GETARG_CSTRING(0);
|
||||
int32 typmod = PG_GETARG_INT32(2);
|
||||
int i;
|
||||
double x[VECTOR_MAX_DIM];
|
||||
int dim = 0;
|
||||
char *pt;
|
||||
char *stringEnd;
|
||||
Vector *result;
|
||||
|
||||
if (*str != '[')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed vector literal: \"%s\"", str),
|
||||
errdetail("Vector contents must start with \"[\".")));
|
||||
|
||||
str++;
|
||||
pt = strtok(str, ",");
|
||||
stringEnd = pt;
|
||||
|
||||
while (pt != NULL && *stringEnd != ']')
|
||||
{
|
||||
if (dim == VECTOR_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));
|
||||
|
||||
x[dim] = strtod(pt, &stringEnd);
|
||||
CheckElement(x[dim]);
|
||||
dim++;
|
||||
|
||||
if (stringEnd == pt)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("invalid input syntax for type vector: \"%s\"", pt)));
|
||||
|
||||
if (*stringEnd != '\0' && *stringEnd != ']')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("invalid input syntax for type vector: \"%s\"", pt)));
|
||||
|
||||
pt = strtok(NULL, ",");
|
||||
}
|
||||
|
||||
if (*stringEnd != ']')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed vector literal"),
|
||||
errdetail("Unexpected end of input.")));
|
||||
|
||||
if (stringEnd[1] != '\0')
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
|
||||
errmsg("malformed vector literal"),
|
||||
errdetail("Junk after closing right brace.")));
|
||||
|
||||
if (dim < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("vector must have at least 1 dimension")));
|
||||
|
||||
CheckExpectedDim(typmod, dim);
|
||||
|
||||
result = InitVector(dim);
|
||||
for (i = 0; i < dim; i++)
|
||||
result->x[i] = x[i];
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert internal representation to textual representation
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_out);
|
||||
Datum
|
||||
vector_out(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *vector = PG_GETARG_VECTOR_P(0);
|
||||
StringInfoData buf;
|
||||
int dim = vector->dim;
|
||||
int i;
|
||||
|
||||
initStringInfo(&buf);
|
||||
|
||||
appendStringInfoChar(&buf, '[');
|
||||
for (i = 0; i < dim; i++)
|
||||
{
|
||||
if (i > 0)
|
||||
appendStringInfoString(&buf, ",");
|
||||
|
||||
appendStringInfoString(&buf, float8out_internal(vector->x[i]));
|
||||
}
|
||||
appendStringInfoChar(&buf, ']');
|
||||
|
||||
PG_FREE_IF_COPY(vector, 0);
|
||||
PG_RETURN_CSTRING(buf.data);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert type modifier
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_typmod_in);
|
||||
Datum
|
||||
vector_typmod_in(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0);
|
||||
int32 *tl;
|
||||
int n;
|
||||
|
||||
tl = ArrayGetIntegerTypmods(ta, &n);
|
||||
|
||||
if (n != 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("invalid type modifier")));
|
||||
|
||||
if (*tl < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("dimensions for type vector must be at least 1")));
|
||||
|
||||
if (*tl > VECTOR_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("dimensions for type vector cannot exceed %d", VECTOR_MAX_DIM)));
|
||||
|
||||
PG_RETURN_INT32(*tl);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert vector to vector
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector);
|
||||
Datum
|
||||
vector(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *arg = PG_GETARG_VECTOR_P(0);
|
||||
int32 typmod = PG_GETARG_INT32(1);
|
||||
|
||||
CheckExpectedDim(typmod, arg->dim);
|
||||
|
||||
PG_RETURN_POINTER(arg);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert array to vector
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(array_to_vector);
|
||||
Datum
|
||||
array_to_vector(PG_FUNCTION_ARGS)
|
||||
{
|
||||
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
|
||||
int32 typmod = PG_GETARG_INT32(1);
|
||||
int i;
|
||||
Vector *result;
|
||||
int16 typlen;
|
||||
bool typbyval;
|
||||
char typalign;
|
||||
Datum *elemsp;
|
||||
bool *nullsp;
|
||||
int nelemsp;
|
||||
|
||||
if (ARR_NDIM(array) > 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("array must be 1-D")));
|
||||
|
||||
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
|
||||
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, &nullsp, &nelemsp);
|
||||
|
||||
if (typmod == -1)
|
||||
{
|
||||
if (nelemsp < 1)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("vector must have at least 1 dimension")));
|
||||
|
||||
if (nelemsp > VECTOR_MAX_DIM)
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
|
||||
errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM)));
|
||||
}
|
||||
else
|
||||
CheckExpectedDim(typmod, nelemsp);
|
||||
|
||||
result = InitVector(nelemsp);
|
||||
for (i = 0; i < nelemsp; i++)
|
||||
{
|
||||
if (nullsp[i])
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
|
||||
errmsg("array must not containing NULLs")));
|
||||
|
||||
if (ARR_ELEMTYPE(array) == INT4OID)
|
||||
result->x[i] = DatumGetInt32(elemsp[i]);
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
|
||||
result->x[i] = DatumGetFloat8(elemsp[i]);
|
||||
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
|
||||
result->x[i] = DatumGetFloat4(elemsp[i]);
|
||||
else
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_EXCEPTION),
|
||||
errmsg("unsupported array type")));
|
||||
|
||||
CheckElement(result->x[i]);
|
||||
}
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 distance between vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(l2_distance);
|
||||
Datum
|
||||
l2_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += pow(a->x[i] - b->x[i], 2);
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt(distance));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 squared distance between vectors
|
||||
* This saves a sqrt calculation
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_l2_squared_distance);
|
||||
Datum
|
||||
vector_l2_squared_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += pow(a->x[i] - b->x[i], 2);
|
||||
|
||||
PG_RETURN_FLOAT8(distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the inner product of two vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(inner_product);
|
||||
Datum
|
||||
inner_product(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += a->x[i] * b->x[i];
|
||||
|
||||
PG_RETURN_FLOAT8(distance);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the negative inner product of two vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_negative_inner_product);
|
||||
Datum
|
||||
vector_negative_inner_product(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += a->x[i] * b->x[i];
|
||||
|
||||
PG_RETURN_FLOAT8(distance * -1);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the cosine distance between two vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(cosine_distance);
|
||||
Datum
|
||||
cosine_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
double norma = 0.0;
|
||||
double normb = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
{
|
||||
distance += a->x[i] * b->x[i];
|
||||
norma += pow(a->x[i], 2);
|
||||
normb += pow(b->x[i], 2);
|
||||
}
|
||||
|
||||
PG_RETURN_FLOAT8(1 - (distance / (sqrt(norma) * sqrt(normb))));
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the distance for spherical k-means
|
||||
* Currently uses angular distance since needs to satisfy triangle inequality
|
||||
* Assumes inputs are unit vectors (skips norm)
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_spherical_distance);
|
||||
Datum
|
||||
vector_spherical_distance(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
double distance = 0.0;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
distance += a->x[i] * b->x[i];
|
||||
|
||||
/* Prevent NaN with acos with loss of precision */
|
||||
if (distance > 1)
|
||||
distance = 1;
|
||||
else if (distance < -1)
|
||||
distance = -1;
|
||||
|
||||
PG_RETURN_FLOAT8(acos(distance) / M_PI);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the dimensions of a vector
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_dims);
|
||||
Datum
|
||||
vector_dims(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
|
||||
PG_RETURN_INT32(a->dim);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the L2 norm of a vector
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_norm);
|
||||
Datum
|
||||
vector_norm(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
double norm = 0.0;
|
||||
|
||||
for (int i = 0; i < a->dim; i++)
|
||||
norm += pow(a->x[i], 2);
|
||||
|
||||
PG_RETURN_FLOAT8(sqrt(norm));
|
||||
}
|
||||
|
||||
/*
|
||||
* Add vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_add);
|
||||
Datum
|
||||
vector_add(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
Vector *result;
|
||||
int i;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
result = InitVector(a->dim);
|
||||
for (i = 0; i < a->dim; i++)
|
||||
result->x[i] = a->x[i] + b->x[i];
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Subtract vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_sub);
|
||||
Datum
|
||||
vector_sub(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = PG_GETARG_VECTOR_P(1);
|
||||
Vector *result;
|
||||
int i;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
result = InitVector(a->dim);
|
||||
for (i = 0; i < a->dim; i++)
|
||||
result->x[i] = a->x[i] - b->x[i];
|
||||
|
||||
PG_RETURN_POINTER(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Internal helper to compare vectors
|
||||
*/
|
||||
int
|
||||
vector_cmp_internal(Vector * a, Vector * b)
|
||||
{
|
||||
int i;
|
||||
|
||||
CheckDims(a, b);
|
||||
|
||||
for (i = 0; i < a->dim; i++)
|
||||
{
|
||||
if (a->x[i] < b->x[i])
|
||||
return -1;
|
||||
|
||||
if (a->x[i] > b->x[i])
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Less than
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_lt);
|
||||
Datum
|
||||
vector_lt(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) < 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Less than or equal
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_le);
|
||||
Datum
|
||||
vector_le(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) <= 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Equal
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_eq);
|
||||
Datum
|
||||
vector_eq(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) == 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Not equal
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_ne);
|
||||
Datum
|
||||
vector_ne(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) != 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Greater than or equal
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_ge);
|
||||
Datum
|
||||
vector_ge(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) >= 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Greater than
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_gt);
|
||||
Datum
|
||||
vector_gt(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_BOOL(vector_cmp_internal(a, b) > 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare vectors
|
||||
*/
|
||||
PG_FUNCTION_INFO_V1(vector_cmp);
|
||||
Datum
|
||||
vector_cmp(PG_FUNCTION_ARGS)
|
||||
{
|
||||
Vector *a = (Vector *) PG_GETARG_VECTOR_P(0);
|
||||
Vector *b = (Vector *) PG_GETARG_VECTOR_P(1);
|
||||
|
||||
PG_RETURN_INT32(vector_cmp_internal(a, b));
|
||||
}
|
||||
41
src/vector.h
Normal file
41
src/vector.h
Normal file
@@ -0,0 +1,41 @@
|
||||
#ifndef VECTOR_H
|
||||
#define VECTOR_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#define VECTOR_MAX_DIM 1024
|
||||
|
||||
#define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim))
|
||||
#define DatumGetVector(x) ((Vector *) PG_DETOAST_DATUM(x))
|
||||
#define PG_GETARG_VECTOR_P(x) DatumGetVector(PG_GETARG_DATUM(x))
|
||||
#define PG_RETURN_VECTOR_P(x) PG_RETURN_POINTER(x)
|
||||
|
||||
typedef struct Vector
|
||||
{
|
||||
int32 vl_len_; /* varlena header (do not touch directly!) */
|
||||
int16 dim; /* number of dimensions */
|
||||
int16 unused;
|
||||
float x[FLEXIBLE_ARRAY_MEMBER];
|
||||
} Vector;
|
||||
|
||||
void PrintVector(char *msg, Vector * vector);
|
||||
int vector_cmp_internal(Vector * a, Vector * b);
|
||||
|
||||
/*
|
||||
* Allocate and initialize a new vector
|
||||
*/
|
||||
static inline Vector *
|
||||
InitVector(int dim)
|
||||
{
|
||||
Vector *result;
|
||||
int size;
|
||||
|
||||
size = VECTOR_SIZE(dim);
|
||||
result = (Vector *) palloc0(size);
|
||||
SET_VARSIZE(result, size);
|
||||
result->dim = dim;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user