#ifndef NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX512_H
#define NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX512_H

#include <immintrin.h>

#include <cmath>
#include <cstdint>
#include <type_traits>
#include <vector>

#include "allocator.hpp"
#include "avx2.hpp"
#include "traits.hpp"

namespace tinysimd
{

namespace abi
{

// default: no AVX512 type available
template <typename scalarType> struct avx512
{
    using type = void;
};

} // namespace abi

#if defined(__AVX512F__) && defined(NEKTAR_ENABLE_SIMD_AVX512)

// forward declaration of concrete types
template <typename T> struct avx512Long8;
struct avx512Double8;
struct avx512Mask;

namespace abi
{

// mapping between abstract types and concrete types
template <> struct avx512<double>
{
    using type = avx512Double8;
};
template <> struct avx512<std::int64_t>
{
    using type = avx512Long8<std::int64_t>;
};
template <> struct avx512<std::uint64_t>
{
    using type = avx512Long8<std::uint64_t>;
};
template <> struct avx512<bool>
{
    using type = avx512Mask;
};

} // namespace abi

////////////////////////////////////////////////////////////////////////////

// concrete types
template <typename T> struct avx512Long8
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 8,
                  "8 bytes Integral required.");

    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 64;

    using scalarType  = T;
    using vectorType  = __m512i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Long8()                       = default;
    inline avx512Long8(const avx512Long8 &rhs) = default;
    inline avx512Long8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Long8(const scalarType rhs)
    {
        _data = _mm512_set1_epi64(rhs);
    }
    explicit inline avx512Long8(scalarArray &rhs)
    {
        _data = _mm512_load_epi64(rhs);
    }

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_epi64(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_epi64(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_epi64(p, _data);
    }

    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_epi64(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_epi64(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_loadu_epi64(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_epi64(rhs);
    }

    // subscript operators are convenient but expensive,
    // avoid them in performance-critical kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};

template <typename T>
inline avx512Long8<T> operator+(avx512Long8<T> lhs, avx512Long8<T> rhs)
{
    return _mm512_add_epi64(lhs._data, rhs._data);
}

template <
    typename T, typename U,
    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx512Long8<T> operator+(avx512Long8<T> lhs, U rhs)
{
    return _mm512_add_epi64(lhs._data, _mm512_set1_epi64(rhs));
}
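
// Illustrative sketch (not part of the header's API; the buffer name 'idx' is
// hypothetical): avx512Long8 is mainly used to hold gather/scatter indices,
// e.g. by load_interleave() further down.
//
//   alignas(avx512Long8<std::uint64_t>::alignment) std::uint64_t idx[8] =
//       {0, 1, 2, 3, 4, 5, 6, 7};
//   avx512Long8<std::uint64_t> index(idx); // aligned load of 8 indices
//   index = index + 8;                     // advance every lane by 8
//   index.store(idx);                      // write the lanes back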

////////////////////////////////////////////////////////////////////////////

struct avx512Double8
{
    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 64;

    using scalarType  = double;
    using vectorType  = __m512d;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Double8()                         = default;
    inline avx512Double8(const avx512Double8 &rhs) = default;
    inline avx512Double8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Double8(const scalarType rhs)
    {
        _data = _mm512_set1_pd(rhs);
    }

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_pd(p, _data);
    }

    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_stream_pd(p, _data);
    }

    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_loadu_pd(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_pd(rhs);
    }

    // gather/scatter with avx2Int8 indices (32-bit indices, 8 lanes)
    template <typename T>
    inline void gather(scalarType const *p, const avx2Int8<T> &indices)
    {
        _data = _mm512_i32gather_pd(indices._data, p, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Int8<T> &indices) const
    {
        _mm512_i32scatter_pd(out, indices._data, _data, 8);
    }

    // gather/scatter with avx512Long8 indices (64-bit indices, 8 lanes)
    template <typename T>
    inline void gather(scalarType const *p, const avx512Long8<T> &indices)
    {
        _data = _mm512_i64gather_pd(indices._data, p, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx512Long8<T> &indices) const
    {
        _mm512_i64scatter_pd(out, indices._data, _data, 8);
    }

    // fma: this = this + a * b
    inline void fma(const avx512Double8 &a, const avx512Double8 &b)
    {
        _data = _mm512_fmadd_pd(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive,
    // avoid them in performance-critical kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }

    // unary ops
    inline void operator+=(avx512Double8 rhs)
    {
        _data = _mm512_add_pd(_data, rhs._data);
    }

    inline void operator-=(avx512Double8 rhs)
    {
        _data = _mm512_sub_pd(_data, rhs._data);
    }

    inline void operator*=(avx512Double8 rhs)
    {
        _data = _mm512_mul_pd(_data, rhs._data);
    }

    inline void operator/=(avx512Double8 rhs)
    {
        _data = _mm512_div_pd(_data, rhs._data);
    }
};

inline avx512Double8 operator+(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_add_pd(lhs._data, rhs._data);
}

inline avx512Double8 operator-(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_sub_pd(lhs._data, rhs._data);
}

inline avx512Double8 operator*(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_mul_pd(lhs._data, rhs._data);
}

inline avx512Double8 operator/(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_div_pd(lhs._data, rhs._data);
}
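
// Illustrative sketch (assumes a 64-byte aligned buffer 'buf' of 8 doubles,
// names hypothetical): the vector type is used much like a scalar double.
//
//   alignas(avx512Double8::alignment) double buf[avx512Double8::width] = {};
//   avx512Double8 a, b(2.0), c(1.0);
//   a.load(buf, is_aligned); // aligned packed load
//   c.fma(a, b);             // c = c + a * b
//   c = c / (a + b);         // element-wise arithmetic via the operators above
//   c.store(buf, is_aligned);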

inline avx512Double8 sqrt(avx512Double8 in)
{
    return _mm512_sqrt_pd(in._data);
}

inline avx512Double8 abs(avx512Double8 in)
{
    return _mm512_abs_pd(in._data);
}

inline avx512Double8 log(avx512Double8 in)
{
    // there is no AVX512 log intrinsic: fall back to scalar std::log per lane
    alignas(avx512Double8::alignment) avx512Double8::scalarArray tmp;
    in.store(tmp, is_aligned);
    for (unsigned int i = 0; i < avx512Double8::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx512Double8 ret;
    ret.load(tmp, is_aligned);
    return ret;
}

// interleaved load: gather one entry from each of the 8 fields of length
// dataLen laid out back-to-back in 'in'
inline void load_interleave(
    const double *in, size_t dataLen,
    std::vector<avx512Double8, allocator<avx512Double8>> &out)
{
    alignas(avx512Double8::alignment) size_t tmp[avx512Double8::width] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx512Long8<size_t>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    size_t nBlocks = dataLen / 4;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[4 * i + 0].gather(in, index0);
        out[4 * i + 1].gather(in, index1);
        out[4 * i + 2].gather(in, index2);
        out[4 * i + 3].gather(in, index3);
        index0 = index0 + 4;
        index1 = index1 + 4;
        index2 = index2 + 4;
        index3 = index3 + 4;
    }

    // spillover loop
    for (size_t i = 4 * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}

// de-interleave and store, per-lane scalar fallback
inline void deinterleave_unalign_store(
    const std::vector<avx512Double8, allocator<avx512Double8>> &in,
    size_t dataLen, double *out)
{
    double *out0 = out;
    double *out1 = out + dataLen;
    double *out2 = out + 2 * dataLen;
    double *out3 = out + 3 * dataLen;
    double *out4 = out + 4 * dataLen;
    double *out5 = out + 5 * dataLen;
    double *out6 = out + 6 * dataLen;
    double *out7 = out + 7 * dataLen;

    for (size_t i = 0; i < dataLen; ++i)
    {
        out0[i] = in[i][0];
        out1[i] = in[i][1];
        out2[i] = in[i][2];
        out3[i] = in[i][3];
        out4[i] = in[i][4];
        out5[i] = in[i][5];
        out6[i] = in[i][6];
        out7[i] = in[i][7];
    }
}

// de-interleave and store via a strided scatter
inline void deinterleave_store(
    const std::vector<avx512Double8, allocator<avx512Double8>> &in,
    size_t dataLen, double *out)
{
    alignas(avx512Double8::alignment) size_t tmp[avx512Double8::width] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx512Long8<size_t>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
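
// Illustrative sketch of the interleave helpers (buffer names hypothetical):
// 'in' holds 8 scalar fields of length dataLen stored back-to-back, so
// out[i] gathers the i-th entry of each of the 8 fields; deinterleave_store()
// reverses the transposition.
//
//   std::vector<avx512Double8, allocator<avx512Double8>> tmp(dataLen);
//   load_interleave(in, dataLen, tmp);     // structure-of-arrays -> vectors
//   // ... operate on tmp with the SIMD operators above ...
//   deinterleave_store(tmp, dataLen, out); // and back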

////////////////////////////////////////////////////////////////////////////

// mask type
// the AVX512 comparisons return a k-register (__mmask8) holding one bit per
// double lane; it is wrapped here as a single 8-bit entity
// note: the 8-bit mask intrinsics used below require AVX512DQ
struct avx512Mask
{
    static constexpr unsigned int width     = 1;
    static constexpr unsigned int alignment = 8;

    using scalarType  = bool;
    using vectorType  = __mmask8;
    using scalarArray = std::uint8_t;

    // storage
    vectorType _data;

    static constexpr scalarType true_v  = true;
    static constexpr scalarType false_v = false;

    // ctors
    inline avx512Mask()                      = default;
    inline avx512Mask(const avx512Mask &rhs) = default;
    inline avx512Mask(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Mask(const scalarType rhs)
    {
        // broadcast the boolean to all 8 mask bits
        _data = rhs ? 0xFF : 0x00;
    }
    explicit inline avx512Mask(scalarArray &rhs)
    {
        _data = _load_mask8(reinterpret_cast<vectorType *>(&rhs));
    }

    // write the mask bits out to memory
    inline void store(scalarArray *p) const
    {
        _store_mask8(reinterpret_cast<vectorType *>(p), _data);
    }
};

inline avx512Mask operator>(avx512Double8 lhs, avx512Double8 rhs)
{
    // rhs < lhs, i.e. lhs > rhs, lane by lane
    return _mm512_cmp_pd_mask(rhs._data, lhs._data, _CMP_LT_OS);
}
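
// Illustrative sketch: the mask produced by operator> collapses to a scalar
// bool through operator&& below, which reports whether *all* lanes compare
// true.
//
//   avx512Double8 x(2.0), y(1.0);
//   avx512Mask m = x > y;         // one mask bit per double lane
//   bool allGreater = m && true;  // true only if every lane of x > y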

inline bool operator&&(avx512Mask lhs, bool rhs)
{
    // true only if every lane of the mask is set
    static constexpr __mmask8 mask_true = 0xFF;
    bool tmp = _ktestc_mask8_u8(lhs._data, mask_true);
    return tmp && rhs;
}

#endif // defined(__AVX512F__) && defined(NEKTAR_ENABLE_SIMD_AVX512)

} // namespace tinysimd

#endif