From 21bcff672295a1be24380f80ad31af343a5814b6 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 8 Apr 2024 13:50:18 -0700 Subject: [PATCH] Added CPU dispatching for halfvec distance functions - #311 Co-authored-by: Arda Aytekin --- src/halfutils.c | 127 ++++++++++++++++++++++++++++++++++-------------- src/halfutils.h | 6 ++- src/halfvec.c | 8 +-- src/halfvec.h | 6 ++- src/vector.c | 2 + 5 files changed, 106 insertions(+), 43 deletions(-) diff --git a/src/halfutils.c b/src/halfutils.c index e961e05..c15d94f 100644 --- a/src/halfutils.c +++ b/src/halfutils.c @@ -3,24 +3,49 @@ #include "halfutils.h" #include "halfvec.h" -#ifdef F16C_SUPPORT +#ifdef HALFVEC_DISPATCH #include + +#if defined(HAVE__GET_CPUID) +#include +#elif defined(HAVE__CPUID) +#include #endif -/* - * Get the L2 squared distance between half vectors - */ -double -HalfvecL2DistanceSquared(HalfVector * a, HalfVector * b) +#ifdef _MSC_VER +#define TARGET_F16C_FMA +#else +#define TARGET_F16C_FMA __attribute__((target("f16c,fma"))) +#endif +#endif + +float (*HalfvecL2DistanceSquared) (int dim, half * ax, half * bx); +float (*HalfvecInnerProduct) (int dim, half * ax, half * bx); + +static float +HalfvecL2DistanceSquaredDefault(int dim, half * ax, half * bx) { - half *ax = a->x; - half *bx = b->x; float distance = 0.0; -#if defined(F16C_SUPPORT) && defined(__FMA__) + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + { + float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); + + distance += diff * diff; + } + + return distance; +} + +#ifdef HALFVEC_DISPATCH +TARGET_F16C_FMA static float +HalfvecL2DistanceSquaredF16cFma(int dim, half * ax, half * bx) +{ + float distance = 0.0; int i; float s[8]; - int count = (a->dim / 8) * 8; + int count = (dim / 8) * 8; __m256 dist = _mm256_setzero_ps(); for (i = 0; i < count; i += 8) @@ -38,39 +63,37 @@ HalfvecL2DistanceSquared(HalfVector * a, HalfVector * b) distance = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7]; - for (; i < a->dim; i++) + for (; i < dim; i++) { float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); distance += diff * diff; } -#else - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) - { - float diff = HalfToFloat4(ax[i]) - HalfToFloat4(bx[i]); - distance += diff * diff; - } + return distance; +} #endif - return (double) distance; -} - -/* - * Get the inner product of two half vectors - */ -double -HalfvecInnerProduct(HalfVector * a, HalfVector * b) +static float +HalfvecInnerProductDefault(int dim, half * ax, half * bx) { - half *ax = a->x; - half *bx = b->x; float distance = 0.0; -#if defined(F16C_SUPPORT) && defined(__FMA__) + /* Auto-vectorized */ + for (int i = 0; i < dim; i++) + distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + return distance; +} + +#ifdef HALFVEC_DISPATCH +TARGET_F16C_FMA static float +HalfvecInnerProductF16cFma(int dim, half * ax, half * bx) +{ + float distance = 0.0; int i; float s[8]; - int count = (a->dim / 8) * 8; + int count = (dim / 8) * 8; __m256 dist = _mm256_setzero_ps(); for (i = 0; i < count; i += 8) @@ -87,13 +110,45 @@ HalfvecInnerProduct(HalfVector * a, HalfVector * b) distance = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7]; - for (; i < a->dim; i++) - distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); -#else - /* Auto-vectorized */ - for (int i = 0; i < a->dim; i++) + for (; i < dim; i++) distance += HalfToFloat4(ax[i]) * HalfToFloat4(bx[i]); + + return distance; +} #endif - return (double) distance; +#ifdef HALFVEC_DISPATCH +static bool +F16cFmaAvailable() +{ + unsigned int exx[4] = {0, 0, 0, 0}; + +#if defined(HAVE__GET_CPUID) + __get_cpuid(1, &exx[0], &exx[1], &exx[2], &exx[3]); +#elif defined(HAVE__CPUID) + __cpuid(exx, 1); +#endif + + /* FMA = 12, F16C = 29 */ + return (exx[2] & (1 << 12)) != 0 && (exx[2] & (1 << 29)) != 0; +} +#endif + +void +HalfvecInit(void) +{ + /* + * Could skip pointer when single function, but no difference in + * performance + */ + HalfvecL2DistanceSquared = HalfvecL2DistanceSquaredDefault; + HalfvecInnerProduct = HalfvecInnerProductDefault; + +#ifdef HALFVEC_DISPATCH + if (F16cFmaAvailable()) + { + HalfvecL2DistanceSquared = HalfvecL2DistanceSquaredF16cFma; + HalfvecInnerProduct = HalfvecInnerProductF16cFma; + } +#endif } diff --git a/src/halfutils.h b/src/halfutils.h index 75562c8..cfcb58c 100644 --- a/src/halfutils.h +++ b/src/halfutils.h @@ -3,7 +3,9 @@ #include "halfvec.h" -double HalfvecL2DistanceSquared(HalfVector * a, HalfVector * b); -double HalfvecInnerProduct(HalfVector * a, HalfVector * b); +extern float (*HalfvecL2DistanceSquared) (int dim, half * ax, half * bx); +extern float (*HalfvecInnerProduct) (int dim, half * ax, half * bx); + +void HalfvecInit(void); #endif diff --git a/src/halfvec.c b/src/halfvec.c index 7f5687b..7e10f6e 100644 --- a/src/halfvec.c +++ b/src/halfvec.c @@ -813,7 +813,7 @@ halfvec_l2_distance(PG_FUNCTION_ARGS) CheckDims(a, b); - PG_RETURN_FLOAT8(sqrt(HalfvecL2DistanceSquared(a, b))); + PG_RETURN_FLOAT8(sqrt((double) HalfvecL2DistanceSquared(a->dim, a->x, b->x))); } /* @@ -828,7 +828,7 @@ halfvec_l2_squared_distance(PG_FUNCTION_ARGS) CheckDims(a, b); - PG_RETURN_FLOAT8(HalfvecL2DistanceSquared(a, b)); + PG_RETURN_FLOAT8((double) HalfvecL2DistanceSquared(a->dim, a->x, b->x)); } /* @@ -843,7 +843,7 @@ halfvec_inner_product(PG_FUNCTION_ARGS) CheckDims(a, b); - PG_RETURN_FLOAT8(HalfvecInnerProduct(a, b)); + PG_RETURN_FLOAT8((double) HalfvecInnerProduct(a->dim, a->x, b->x)); } /* @@ -858,7 +858,7 @@ halfvec_negative_inner_product(PG_FUNCTION_ARGS) CheckDims(a, b); - PG_RETURN_FLOAT8(-HalfvecInnerProduct(a, b)); + PG_RETURN_FLOAT8((double) -HalfvecInnerProduct(a->dim, a->x, b->x)); } /* diff --git a/src/halfvec.h b/src/halfvec.h index e0ab7c6..868a337 100644 --- a/src/halfvec.h +++ b/src/halfvec.h @@ -7,10 +7,14 @@ #include "vector.h" +#if defined(__x86_64__) || defined(_M_AMD64) +#define HALFVEC_DISPATCH +#endif + /* F16C has better performance than _Float16 (on x86-64) */ #if defined(__F16C__) #define F16C_SUPPORT -#elif defined(__FLT16_MAX__) +#elif defined(__FLT16_MAX__) && !defined(HALFVEC_DISPATCH) #define FLT16_SUPPORT #endif diff --git a/src/vector.c b/src/vector.c index 31e7386..7032b3e 100644 --- a/src/vector.c +++ b/src/vector.c @@ -6,6 +6,7 @@ #include "catalog/pg_type.h" #include "common/shortest_dec.h" #include "fmgr.h" +#include "halfutils.h" #include "halfvec.h" #include "hnsw.h" #include "ivfflat.h" @@ -41,6 +42,7 @@ PGDLLEXPORT void _PG_init(void); void _PG_init(void) { + HalfvecInit(); HnswInit(); IvfflatInit(); }