From 55845bfd5f560b0d21d3338f5adfab2ac95b2d8e Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Mon, 15 Apr 2024 10:01:05 -0700
Subject: [PATCH] Added SIMD version of cosine distance

---
 src/halfutils.c | 76 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/halfutils.h |  1 +
 src/halfvec.c   | 19 +------------
 3 files changed, 78 insertions(+), 18 deletions(-)

diff --git a/src/halfutils.c b/src/halfutils.c
index 1123a1b..44d91e9 100644
--- a/src/halfutils.c
+++ b/src/halfutils.c
@@ -21,6 +21,7 @@
 
 float		(*HalfvecL2SquaredDistance) (int dim, half * ax, half * bx);
 float		(*HalfvecInnerProduct) (int dim, half * ax, half * bx);
+double		(*HalfvecCosineSimilarity) (int dim, half * ax, half * bx);	/* assigned in HalfvecInit */
 
 static float
 HalfvecL2SquaredDistanceDefault(int dim, half * ax, half * bx)
@@ -117,6 +118,79 @@ HalfvecInnerProductF16cFma(int dim, half * ax, half * bx)
 }
 #endif
 
+static double	/* scalar version: dot(a, b) / sqrt(sum(a^2) * sum(b^2)) over dim elements */
+HalfvecCosineSimilarityDefault(int dim, half * ax, half * bx)
+{
+	float		similarity = 0.0;	/* running dot product of a and b */
+	float		norma = 0.0;	/* running sum of squares of a */
+	float		normb = 0.0;	/* running sum of squares of b */
+
+	/* Auto-vectorized */
+	for (int i = 0; i < dim; i++)
+	{
+		float		axi = HalfToFloat4(ax[i]);
+		float		bxi = HalfToFloat4(bx[i]);
+
+		similarity += axi * bxi;
+		norma += axi * axi;
+		normb += bxi * bxi;
+	}
+
+	/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
+	return (double) similarity / sqrt((double) norma * (double) normb);	/* float sums are widened to double first; no zero-norm guard here */
+}
+
+#ifdef HALFVEC_DISPATCH
+TARGET_F16C_FMA static double	/* same formula, written with explicit F16C + FMA intrinsics, 8 elements per step */
+HalfvecCosineSimilarityF16cFma(int dim, half * ax, half * bx)
+{
+	float		similarity;
+	float		norma;
+	float		normb;
+	int			i;	/* shared by both loops so the tail resumes where the vector loop stops */
+	float		s[8];	/* scratch buffer for spilling one __m256 (8 floats) */
+	int			count = (dim / 8) * 8;	/* dim rounded down to a multiple of 8 */
+	__m256		sim = _mm256_setzero_ps();	/* 8 partial dot products */
+	__m256		na = _mm256_setzero_ps();	/* 8 partial sums of squares of a */
+	__m256		nb = _mm256_setzero_ps();	/* 8 partial sums of squares of b */
+
+	for (i = 0; i < count; i += 8)
+	{
+		__m128i		axi = _mm_loadu_si128((__m128i *) (ax + i));	/* unaligned 128-bit load starting at ax[i] */
+		__m128i		bxi = _mm_loadu_si128((__m128i *) (bx + i));
+		__m256		axs = _mm256_cvtph_ps(axi);	/* widen the packed halves to 8 floats */
+		__m256		bxs = _mm256_cvtph_ps(bxi);
+
+		sim = _mm256_fmadd_ps(axs, bxs, sim);	/* sim += axs * bxs */
+		na = _mm256_fmadd_ps(axs, axs, na);	/* na += axs * axs */
+		nb = _mm256_fmadd_ps(bxs, bxs, nb);	/* nb += bxs * bxs */
+	}
+
+	_mm256_storeu_ps(s, sim);	/* spill the lanes, then add them up in scalar code */
+	similarity = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7];
+
+	_mm256_storeu_ps(s, na);	/* s is reused for each accumulator in turn */
+	norma = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7];
+
+	_mm256_storeu_ps(s, nb);
+	normb = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7];
+
+	/* Auto-vectorized */
+	for (; i < dim; i++)	/* tail: the dim % 8 elements left after the vector loop */
+	{
+		float		axi = HalfToFloat4(ax[i]);
+		float		bxi = HalfToFloat4(bx[i]);
+
+		similarity += axi * bxi;
+		norma += axi * axi;
+		normb += bxi * bxi;
+	}
+
+	/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
+	return (double) similarity / sqrt((double) norma * (double) normb);	/* same final expression as the Default version */
+}
+#endif
+
 #ifdef HALFVEC_DISPATCH
 #define CPU_FEATURE_FMA		(1 << 12)
 #define CPU_FEATURE_F16C	(1 << 29)
@@ -145,12 +219,14 @@ HalfvecInit(void)
 	 */
 	HalfvecL2SquaredDistance = HalfvecL2SquaredDistanceDefault;
 	HalfvecInnerProduct = HalfvecInnerProductDefault;
+	HalfvecCosineSimilarity = HalfvecCosineSimilarityDefault;	/* scalar version is always installed first */
 
 #ifdef HALFVEC_DISPATCH
 	if (SupportsCpuFeature(CPU_FEATURE_FMA | CPU_FEATURE_F16C))
 	{
 		HalfvecL2SquaredDistance = HalfvecL2SquaredDistanceF16cFma;
 		HalfvecInnerProduct = HalfvecInnerProductF16cFma;
+		HalfvecCosineSimilarity = HalfvecCosineSimilarityF16cFma;	/* replaced only when the FMA | F16C check passes */
 	}
 #endif
 }
diff --git a/src/halfutils.h b/src/halfutils.h
index 5881315..959d96e 100644
--- a/src/halfutils.h
+++ b/src/halfutils.h
@@ -12,6 +12,7 @@
 
 extern float (*HalfvecL2SquaredDistance) (int dim, half * ax, half * bx);
 extern float (*HalfvecInnerProduct) (int dim, half * ax, half * bx);
+extern double (*HalfvecCosineSimilarity) (int dim, half * ax, half * bx);	/* returns double, unlike the two float pointers above */
 
 void		HalfvecInit(void);
 
diff --git a/src/halfvec.c b/src/halfvec.c
index 72edbc5..d57ed28 100644
--- a/src/halfvec.c
+++ b/src/halfvec.c
@@ -626,28 +626,11 @@ halfvec_cosine_distance(PG_FUNCTION_ARGS)
 {
 	HalfVector *a = PG_GETARG_HALFVEC_P(0);
 	HalfVector *b = PG_GETARG_HALFVEC_P(1);
-	half	   *ax = a->x;
-	half	   *bx = b->x;
-	float		distance = 0.0;
-	float		norma = 0.0;
-	float		normb = 0.0;
 	double		similarity;
 
 	CheckDims(a, b);
 
-	/* Auto-vectorized */
-	for (int i = 0; i < a->dim; i++)
-	{
-		float		axi = HalfToFloat4(ax[i]);
-		float		bxi = HalfToFloat4(bx[i]);
-
-		distance += axi * bxi;
-		norma += axi * axi;
-		normb += bxi * bxi;
-	}
-
-	/* Use sqrt(a * b) over sqrt(a) * sqrt(b) */
-	similarity = (double) distance / sqrt((double) norma * (double) normb);
+	similarity = HalfvecCosineSimilarity(a->dim, a->x, b->x);	/* call through the function pointer; only a->dim is passed - presumably CheckDims above ensures b matches (confirm) */
 
 #ifdef _MSC_VER
 	/* /fp:fast may not propagate NaN */