Added casting between vector and halfvec

This commit is contained in:
Andrew Kane
2024-03-28 14:04:59 -07:00
parent 45ef8f8a45
commit 4f1a379638
6 changed files with 78 additions and 0 deletions

View File

@@ -114,3 +114,15 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 2 halfvec_norm(halfvec);
CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE CAST (halfvec AS vector)
WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS ASSIGNMENT;
CREATE CAST (vector AS halfvec)
WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS ASSIGNMENT;

View File

@@ -415,3 +415,17 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
FUNCTION 2 halfvec_norm(halfvec);
-- extension casts
CREATE FUNCTION halfvec_to_vector(halfvec, integer, boolean) RETURNS vector
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION vector_to_halfvec(vector, integer, boolean) RETURNS halfvec
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE CAST (halfvec AS vector)
WITH FUNCTION halfvec_to_vector(halfvec, integer, boolean) AS ASSIGNMENT;
CREATE CAST (vector AS halfvec)
WITH FUNCTION vector_to_halfvec(vector, integer, boolean) AS ASSIGNMENT;

View File

@@ -14,6 +14,7 @@
#include "utils/float.h"
#include "utils/lsyscache.h"
#include "utils/numeric.h"
#include "vector.h"
#if PG_VERSION_NUM < 130000
#define TYPALIGN_DOUBLE 'd'
@@ -722,6 +723,24 @@ halfvec_to_float4(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(result);
}
/*
* Convert half vector to vector
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_vector);
Datum
halfvec_to_vector(PG_FUNCTION_ARGS)
{
HalfVector *vec = PG_GETARG_HALFVEC_P(0);
/* TODO Check vector dims in InitVector */
Vector *result = InitVector(vec->dim);
for (int i = 0; i < vec->dim; i++)
result->x[i] = HalfToFloat4(vec->x[i]);
PG_RETURN_POINTER(result);
}
/*
* Get the L2 distance between half vectors
*/

View File

@@ -5,6 +5,7 @@
#include "catalog/pg_type.h"
#include "common/shortest_dec.h"
#include "fmgr.h"
#include "halfvec.h"
#include "hnsw.h"
#include "ivfflat.h"
#include "lib/stringinfo.h"
@@ -532,6 +533,24 @@ vector_to_float4(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(result);
}
/*
* Convert vector to half vec
*/
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_halfvec);
Datum
vector_to_halfvec(PG_FUNCTION_ARGS)
{
Vector *vec = PG_GETARG_VECTOR_P(0);
/* TODO Check halfvec dims in InitHalfVector */
HalfVector *result = InitHalfVector(vec->dim);
for (int i = 0; i < vec->dim; i++)
result->x[i] = Float4ToHalf(vec->x[i]);
PG_RETURN_POINTER(result);
}
/*
* Get the L2 distance between vectors
*/

View File

@@ -46,6 +46,18 @@ SELECT '[1,2,3]'::vector::real[];
{1,2,3}
(1 row)
SELECT '[1,2,3]'::vector::halfvec;
halfvec
---------
[1,2,3]
(1 row)
SELECT '[1,2,3]'::halfvec::vector;
vector
---------
[1,2,3]
(1 row)
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
ERROR: vector cannot have more than 16000 dimensions
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;

View File

@@ -10,6 +10,8 @@ SELECT '{-Infinity}'::real[]::vector;
SELECT '{}'::real[]::vector;
SELECT '{{1}}'::real[]::vector;
SELECT '[1,2,3]'::vector::real[];
SELECT '[1,2,3]'::vector::halfvec;
SELECT '[1,2,3]'::halfvec::vector;
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;