#ifndef NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H
#define NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H

#include <immintrin.h>

#include <cmath>
#include <cstdint>
#include <vector>

// project headers providing the alignment-flag traits, the aligned allocator
// and the sse2Int4 index type used below
#include "allocator.hpp"
#include "sse2.hpp"
#include "traits.hpp"
template <typename scalarType> struct avx2
{
    using type = void;
};
#if defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)
// forward declaration of concrete types
template <typename T> struct avx2Int8;
template <typename T> struct avx2Long4;
struct avx2Double4;
struct avx2Mask;
// mapping between abstract types and concrete types
template <> struct avx2<double>
{
    using type = avx2Double4;
};
template <> struct avx2<std::int64_t>
{
    using type = avx2Long4<std::int64_t>;
};
template <> struct avx2<std::uint64_t>
{
    using type = avx2Long4<std::uint64_t>;
};
template <> struct avx2<bool>
{
    using type = avx2Mask;
};
#if defined(__AVX512F__) && defined(NEKTAR_ENABLE_SIMD_AVX512)
template <> struct avx2<std::int32_t>
{
    using type = avx2Int8<std::int32_t>;
};
template <> struct avx2<std::uint32_t>
{
    using type = avx2Int8<std::uint32_t>;
};
#endif
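// Usage sketch (illustrative comment, not part of the original header): the
// avx2 trait maps a scalar type to its concrete vector type, so generic code
// can be written against the trait rather than the concrete structs. The
// alias name vec_t below is hypothetical.
//
//     using vec_t = avx2<double>::type;          // avx2Double4 under AVX2
//     static_assert(vec_t::width == 4, "four doubles per 256-bit register");
//     static_assert(vec_t::alignment == 32, "32-byte aligned loads/stores");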
// concrete vector types
template <typename T> struct avx2Int8
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 4,
                  "4 bytes Integral required.");

    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;
    inline avx2Int8()                      = default;
    inline avx2Int8(const avx2Int8 &rhs)   = default;
    inline avx2Int8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Int8(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }
    explicit inline avx2Int8(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }
    // store
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }
    // load
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }
    // subscript operators are convenient but expensive; avoid them inside a
    // computational kernel
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};
template <typename T>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, avx2Int8<T> rhs)
{
    return _mm256_add_epi32(lhs._data, rhs._data);
}

template <typename T, typename U,
          typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, U rhs)
{
    return _mm256_add_epi32(lhs._data, _mm256_set1_epi32(rhs));
}
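// Usage sketch (illustrative comment; the array name is hypothetical):
// avx2Int8 packs eight 32-bit integers per 256-bit register, and the second
// operator+ overload lets a plain scalar be added to every lane.
//
//     alignas(avx2Int8<int>::alignment) int idx[8] = {0, 1, 2, 3, 4, 5, 6, 7};
//     avx2Int8<int> offsets(idx); // aligned load from a scalarArray
//     offsets = offsets + 8;      // advance all eight lanes by 8
//     offsets.store(idx);         // aligned store back to the array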
template <typename T> struct avx2Long4
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 8,
                  "8 bytes Integral required.");

    static constexpr unsigned int width     = 4;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;
    inline avx2Long4()                      = default;
    inline avx2Long4(const avx2Long4 &rhs)  = default;
    inline avx2Long4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Long4(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }
    explicit inline avx2Long4(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }
    // store
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }
    // load
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }
    // subscript operators are convenient but expensive; avoid them inside a
    // computational kernel
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};
template <typename T>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, avx2Long4<T> rhs)
{
    return _mm256_add_epi64(lhs._data, rhs._data);
}

template <typename T, typename U,
          typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, U rhs)
{
    return _mm256_add_epi64(lhs._data, _mm256_set1_epi64x(rhs));
}
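// Usage sketch (illustrative comment; the stride n is hypothetical):
// avx2Long4 packs four 64-bit integers and is the index type consumed by the
// avx2Double4 gather/scatter members defined below.
//
//     alignas(avx2Long4<size_t>::alignment) size_t idx[4] = {0, n, 2 * n, 3 * n};
//     avx2Long4<size_t> index(idx); // four gather/scatter offsets
//     index = index + 1;            // move every offset to the next element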
struct avx2Double4
{
    static constexpr unsigned width     = 4;
    static constexpr unsigned alignment = 32;

    using scalarType  = double;
    using vectorType  = __m256d;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;
    inline avx2Double4()                       = default;
    inline avx2Double4(const avx2Double4 &rhs) = default;
    inline avx2Double4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Double4(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }
    // store
    inline void store(scalarType *p) const
    {
        _mm256_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_stream_pd(p, _data);
    }
    // load
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_pd(p);
    }
    // load four scalars from (possibly) non-contiguous locations
    inline void load(scalarType const *a, scalarType const *b,
                     scalarType const *c, scalarType const *d)
    {
        __m128d t1, t2, t3, t4;
        __m256d t5;
        t1    = _mm_load_sd(a);
        t2    = _mm_loadh_pd(t1, b);
        t3    = _mm_load_sd(c);
        t4    = _mm_loadh_pd(t3, d);
        t5    = _mm256_castpd128_pd256(t2);
        _data = _mm256_insertf128_pd(t5, t4, 1);
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }
    // gather/scatter with 32-bit indices (sse2Int4)
    template <typename T>
    inline void gather(scalarType const *p, const sse2Int4<T> &indices)
    {
        _data = _mm256_i32gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const sse2Int4<T> &indices) const
    {
        // no scatter intrinsic in AVX2: store and write lanes out one by one
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm_extract_epi32(indices._data, 0)] = tmp[0];
        out[_mm_extract_epi32(indices._data, 1)] = tmp[1];
        out[_mm_extract_epi32(indices._data, 2)] = tmp[2];
        out[_mm_extract_epi32(indices._data, 3)] = tmp[3];
    }
    // gather/scatter with 64-bit indices (avx2Long4)
    template <typename T>
    inline void gather(scalarType const *p, const avx2Long4<T> &indices)
    {
        _data = _mm256_i64gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Long4<T> &indices) const
    {
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm256_extract_epi64(indices._data, 0)] = tmp[0];
        out[_mm256_extract_epi64(indices._data, 1)] = tmp[1];
        out[_mm256_extract_epi64(indices._data, 2)] = tmp[2];
        out[_mm256_extract_epi64(indices._data, 3)] = tmp[3];
    }
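    // Usage sketch (illustrative comment; basePtr and n are hypothetical):
    // gather reads one double per lane from arbitrary offsets and scatter
    // writes them back, e.g. one value from each of four columns of length n.
    //
    //     alignas(32) size_t offsets[4] = {0, n, 2 * n, 3 * n};
    //     avx2Long4<size_t> idx(offsets);
    //     avx2Double4 v;
    //     v.gather(basePtr, idx);  // v[i] = basePtr[offsets[i]]
    //     v.scatter(basePtr, idx); // basePtr[offsets[i]] = v[i]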
    // fused multiply-add: _data += a * b
    inline void fma(const avx2Double4 &a, const avx2Double4 &b)
    {
        _data = _mm256_fmadd_pd(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive; avoid them inside a
    // computational kernel
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
    inline void operator+=(avx2Double4 rhs)
    {
        _data = _mm256_add_pd(_data, rhs._data);
    }

    inline void operator-=(avx2Double4 rhs)
    {
        _data = _mm256_sub_pd(_data, rhs._data);
    }

    inline void operator*=(avx2Double4 rhs)
    {
        _data = _mm256_mul_pd(_data, rhs._data);
    }

    inline void operator/=(avx2Double4 rhs)
    {
        _data = _mm256_div_pd(_data, rhs._data);
    }
};
inline avx2Double4 operator+(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_add_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator-(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_sub_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator*(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_mul_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator/(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_div_pd(lhs._data, rhs._data);
}

inline avx2Double4 sqrt(avx2Double4 in)
{
    return _mm256_sqrt_pd(in._data);
}
inline avx2Double4 abs(avx2Double4 in)
{
    // there is no _mm256_abs_pd intrinsic; clear the sign bit instead
    static const __m256d sign_mask = _mm256_set1_pd(-0.); // sign bit only
    return _mm256_andnot_pd(sign_mask, in._data);         // ~sign_mask & in
}
inline avx2Double4 log(avx2Double4 in)
{
    // there is no AVX2 log intrinsic: fall back to std::log on each lane
    alignas(avx2Double4::alignment) avx2Double4::scalarArray tmp;
    in.store(tmp);
    tmp[0] = std::log(tmp[0]);
    tmp[1] = std::log(tmp[1]);
    tmp[2] = std::log(tmp[2]);
    tmp[3] = std::log(tmp[3]);
    avx2Double4 ret;
    ret.load(tmp);
    return ret;
}
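// Usage sketch (illustrative comment; the pointers a, b, c and the length n
// are hypothetical): a simple fused multiply-add kernel over 32-byte aligned,
// contiguous data, processing four doubles per iteration.
//
//     for (size_t i = 0; i < n; i += avx2Double4::width)
//     {
//         avx2Double4 va, vb, vc;
//         va.load(a + i, is_aligned);
//         vb.load(b + i, is_aligned);
//         vc.load(c + i, is_aligned);
//         vc.fma(va, vb);              // vc += va * vb
//         vc.store(c + i, is_aligned);
//     }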
// gather dataLen points of four fields stored back-to-back, so that on exit
// out[i] holds {in[i], in[i + dataLen], in[i + 2*dataLen], in[i + 3*dataLen]}
inline void load_interleave(
    const double *in, size_t dataLen,
    std::vector<avx2Double4, allocator<avx2Double4>> &out)
{
    size_t nBlocks = dataLen / 4;

    alignas(32) size_t tmp[4] = {0, dataLen, 2 * dataLen, 3 * dataLen};
    using index_t             = avx2Long4<size_t>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[4 * i + 0].gather(in, index0);
        out[4 * i + 1].gather(in, index1);
        out[4 * i + 2].gather(in, index2);
        out[4 * i + 3].gather(in, index3);
        index0 = index0 + 4;
        index1 = index1 + 4;
        index2 = index2 + 4;
        index3 = index3 + 4;
    }

    // spillover loop
    for (size_t i = 4 * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}
// inverse operation: scatter each lane of in[i] back to its field
inline void deinterleave_store(
    const std::vector<avx2Double4, allocator<avx2Double4>> &in,
    size_t dataLen, double *out)
{
    alignas(32) size_t tmp[4] = {0, dataLen, 2 * dataLen, 3 * dataLen};
    using index_t             = avx2Long4<size_t>;
    index_t index0(tmp);
    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
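// Usage sketch (illustrative comment; nq and the fields buffer are
// hypothetical): four fields of nq points stored back-to-back are transposed
// into one avx2Double4 per point, processed, and written back in place.
//
//     std::vector<avx2Double4, allocator<avx2Double4>> work(nq);
//     load_interleave(fields, nq, work);    // work[i] = point i of all 4 fields
//     for (auto &v : work)
//     {
//         v = v * v;                        // some per-point kernel
//     }
//     deinterleave_store(work, nq, fields); // back to field-by-field layout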
// mask type: an integer vector whose lanes are all-ones (true) or all-zeros
// (false); only very limited support is provided
struct avx2Mask : avx2Long4<std::uint64_t>
{
    // bring in ctors
    using avx2Long4::avx2Long4;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx2Mask operator>(avx2Double4 lhs, avx2Double4 rhs)
{
    // lanes where lhs > rhs compare to all-ones
    return reinterpret_cast<__m256i>(_mm256_cmp_pd(rhs._data, lhs._data, 1));
}

inline bool operator&&(avx2Mask lhs, bool rhs)
{
    // true only if every lane of the mask is set
    bool tmp =
        _mm256_testc_si256(lhs._data, _mm256_set1_epi64x(avx2Mask::true_v));

    return tmp && rhs;
}
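// Usage sketch (illustrative comment; rho, rhoMin and rhoPtr are
// hypothetical): a comparison yields a mask, and operator&& reduces it so
// that an all-lanes condition can drive an ordinary branch.
//
//     avx2Double4 rho, rhoMin(1.0e-12);
//     rho.load(rhoPtr);
//     if ((rho > rhoMin) && true) // true only when every lane exceeds rhoMin
//     {
//         // all four values pass the check: take the vectorised path
//     }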
#endif // defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)

#endif // NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H