diff --git a/src/halfutils.c b/src/halfutils.c index c53da83..7485e9d 100644 --- a/src/halfutils.c +++ b/src/halfutils.c @@ -118,8 +118,11 @@ HalfvecInnerProductF16cFma(int dim, half * ax, half * bx) #endif #ifdef HALFVEC_DISPATCH +#define FEATURE_FMA (1 << 12) +#define FEATURE_F16C (1 << 29) + static bool -F16cFmaAvailable() +SupportsFeature(unsigned int feature) { unsigned int exx[4] = {0, 0, 0, 0}; @@ -129,8 +132,7 @@ F16cFmaAvailable() __cpuid(exx, 1); #endif - /* FMA = 12, F16C = 29 */ - return (exx[2] & (1 << 12)) != 0 && (exx[2] & (1 << 29)) != 0; + return (exx[2] & feature) == feature; } #endif @@ -145,7 +147,7 @@ HalfvecInit(void) HalfvecInnerProduct = HalfvecInnerProductDefault; #ifdef HALFVEC_DISPATCH - if (F16cFmaAvailable()) + if (SupportsFeature(FEATURE_FMA | FEATURE_F16C)) { HalfvecL2DistanceSquared = HalfvecL2DistanceSquaredF16cFma; HalfvecInnerProduct = HalfvecInnerProductF16cFma;