Use normalize_l2 for normalization

This commit is contained in:
Andrew Kane
2023-10-16 16:42:40 -07:00
parent 9ed7e63fb7
commit 0054a9c40a
8 changed files with 32 additions and 20 deletions

View File

@@ -292,4 +292,4 @@ CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING hnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);
FUNCTION 3 normalize_l2(vector);

View File

@@ -167,7 +167,7 @@ hnswhandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 2;
amroutine->amsupport = 3;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif

View File

@@ -19,6 +19,7 @@
/* Support functions */
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
#define HNSW_NORMALIZE_PROC 3
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -147,6 +148,7 @@ typedef struct HnswBuildState
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation;
/* Variables */
@@ -220,6 +222,7 @@ typedef struct HnswScanOpaqueData
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation;
} HnswScanOpaqueData;
@@ -255,7 +258,7 @@ typedef struct HnswVacuumState
int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
bool HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result);
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);

View File

@@ -278,11 +278,8 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState *
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec))
return false;
}
if (!HnswNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, collation, &value, buildstate->normvec))
return false;
/* Copy value to element so accessible outside of memory context */
memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions));
@@ -413,6 +410,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
/* Get support functions */
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
buildstate->collation = index->rd_indcollation[0];
buildstate->elements = NIL;

View File

@@ -417,6 +417,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
{
Datum value;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
HnswElement entryPoint;
HnswElement element;
int m = HnswGetM(index);
@@ -432,11 +433,9 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
/* Normalize if needed */
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
if (normprocinfo != NULL)
{
if (!HnswNormValue(normprocinfo, collation, &value, NULL))
return false;
}
normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
if (!HnswNormValue(normprocinfo, normalizeprocinfo, collation, &value, NULL))
return false;
/* Create an element */
element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m));

View File

@@ -78,6 +78,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
so->collation = index->rd_indcollation[0];
scan->opaque = so;
@@ -140,8 +141,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
HnswNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL);
}
GetScanItems(scan, value);

View File

@@ -55,9 +55,20 @@ HnswOptionalProcInfo(Relation rel, uint16 procnum)
* if it's different than the original value
*/
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
double norm;
if (normalizeprocinfo != NULL)
{
*value = FunctionCall1Coll(normalizeprocinfo, collation, *value);
return true;
}
if (procinfo == NULL)
return true;
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{

View File

@@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
[1,1,1]
[1,2,3]
[1,2,4]
(3 rows)
[0,0,0]
(4 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
count
-------
3
4
(1 row)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
count
-------
3
4
(1 row)
DROP TABLE t;