mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-03 11:10:56 +08:00
Use normalize_l2 for normalization
This commit is contained in:
@@ -292,4 +292,4 @@ CREATE OPERATOR CLASS vector_cosine_ops
|
||||
FOR TYPE vector USING hnsw AS
|
||||
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 vector_negative_inner_product(vector, vector),
|
||||
FUNCTION 2 vector_norm(vector);
|
||||
FUNCTION 3 normalize_l2(vector);
|
||||
|
||||
@@ -167,7 +167,7 @@ hnswhandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 2;
|
||||
amroutine->amsupport = 3;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
/* Support functions */
|
||||
#define HNSW_DISTANCE_PROC 1
|
||||
#define HNSW_NORM_PROC 2
|
||||
#define HNSW_NORMALIZE_PROC 3
|
||||
|
||||
#define HNSW_VERSION 1
|
||||
#define HNSW_MAGIC_NUMBER 0xA953A953
|
||||
@@ -147,6 +148,7 @@ typedef struct HnswBuildState
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
@@ -220,6 +222,7 @@ typedef struct HnswScanOpaqueData
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
} HnswScanOpaqueData;
|
||||
|
||||
@@ -255,7 +258,7 @@ typedef struct HnswVacuumState
|
||||
int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result);
|
||||
void HnswCommitBuffer(Buffer buf, GenericXLogState *state);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
|
||||
@@ -278,11 +278,8 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState *
|
||||
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec))
|
||||
return false;
|
||||
}
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, collation, &value, buildstate->normvec))
|
||||
return false;
|
||||
|
||||
/* Copy value to element so accessible outside of memory context */
|
||||
memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions));
|
||||
@@ -413,6 +410,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
|
||||
/* Get support functions */
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
buildstate->elements = NIL;
|
||||
|
||||
@@ -417,6 +417,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
{
|
||||
Datum value;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
HnswElement entryPoint;
|
||||
HnswElement element;
|
||||
int m = HnswGetM(index);
|
||||
@@ -432,11 +433,9 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
|
||||
/* Normalize if needed */
|
||||
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(normprocinfo, collation, &value, NULL))
|
||||
return false;
|
||||
}
|
||||
normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
|
||||
if (!HnswNormValue(normprocinfo, normalizeprocinfo, collation, &value, NULL))
|
||||
return false;
|
||||
|
||||
/* Create an element */
|
||||
element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m));
|
||||
|
||||
@@ -78,6 +78,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
|
||||
/* Set support functions */
|
||||
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
|
||||
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
|
||||
so->collation = index->rd_indcollation[0];
|
||||
|
||||
scan->opaque = so;
|
||||
@@ -140,8 +141,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
|
||||
|
||||
/* Fine if normalization fails */
|
||||
if (so->normprocinfo != NULL)
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value, NULL);
|
||||
HnswNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL);
|
||||
}
|
||||
|
||||
GetScanItems(scan, value);
|
||||
|
||||
@@ -55,9 +55,20 @@ HnswOptionalProcInfo(Relation rel, uint16 procnum)
|
||||
* if it's different than the original value
|
||||
*/
|
||||
bool
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
|
||||
HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
double norm;
|
||||
|
||||
if (normalizeprocinfo != NULL)
|
||||
{
|
||||
*value = FunctionCall1Coll(normalizeprocinfo, collation, *value);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (procinfo == NULL)
|
||||
return true;
|
||||
|
||||
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
|
||||
@@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
|
||||
[1,1,1]
|
||||
[1,2,3]
|
||||
[1,2,4]
|
||||
(3 rows)
|
||||
[0,0,0]
|
||||
(4 rows)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
|
||||
count
|
||||
-------
|
||||
3
|
||||
4
|
||||
(1 row)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
|
||||
count
|
||||
-------
|
||||
3
|
||||
4
|
||||
(1 row)
|
||||
|
||||
DROP TABLE t;
|
||||
|
||||
Reference in New Issue
Block a user