#ifndef NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H
#define NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H

#if defined(__x86_64__)
#include <immintrin.h>
#if defined(__INTEL_COMPILER) && !defined(TINYSIMD_HAS_SVML)
#define TINYSIMD_HAS_SVML
#endif
#endif

#include "allocator.hpp"
#include "sse2.hpp"
#include "traits.hpp"

#include <cmath>
#include <cstdint>
#include <type_traits>
#include <vector>
namespace tinysimd
{
namespace abi
{

// Generic mapping from scalar type to vector type: the unspecialised case has
// no concrete vector type.
template <typename scalarType, int width = 0> struct avx2
{
    using type = void;
};

} // namespace abi
} // namespace tinysimd
#if defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)

namespace tinysimd
{

// forward declaration of concrete types
template <typename T> struct avx2Int8;
template <typename T> struct avx2Long4;
struct avx2Double4;
struct avx2Float8;
struct avx2Mask4;
struct avx2Mask8;
namespace abi
{

// mapping between abstract types and concrete vector types
template <> struct avx2<double>
{
    using type = avx2Double4;
};
template <> struct avx2<float>
{
    using type = avx2Float8;
};
template <> struct avx2<std::int64_t>
{
    using type = avx2Long4<std::int64_t>;
};
template <> struct avx2<std::uint64_t>
{
    using type = avx2Long4<std::uint64_t>;
};
template <> struct avx2<std::int32_t>
{
    using type = avx2Int8<std::int32_t>;
};
template <> struct avx2<std::uint32_t>
{
    using type = avx2Int8<std::uint32_t>;
};
template <> struct avx2<bool, 4>
{
    using type = avx2Mask4;
};
template <> struct avx2<bool, 8>
{
    using type = avx2Mask8;
};

} // namespace abi
// concrete types

// 8 x 32-bit integer lanes
template <typename T> struct avx2Int8
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 4,
                  "4 bytes Integral required.");

    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;
    // ctors
    inline avx2Int8()                    = default;
    inline avx2Int8(const avx2Int8 &rhs) = default;
    inline avx2Int8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Int8(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }
    explicit inline avx2Int8(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }
    // store packed
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }
    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }
    // subscript: convenient but expensive, should not be used in hot loops
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }
};
template <typename T>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, avx2Int8<T> rhs)
{
    return _mm256_add_epi32(lhs._data, rhs._data);
}

template <
    typename T, typename U,
    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, U rhs)
{
    return _mm256_add_epi32(lhs._data, _mm256_set1_epi32(rhs));
}
// 4 x 64-bit integer lanes
template <typename T> struct avx2Long4
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 8,
                  "8 bytes Integral required.");

    static constexpr unsigned int width     = 4;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;
    // ctors
    inline avx2Long4()                     = default;
    inline avx2Long4(const avx2Long4 &rhs) = default;
    inline avx2Long4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Long4(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }
    explicit inline avx2Long4(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }
    // store packed
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }
    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }
    // subscript: convenient but expensive, should not be used in hot loops
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }
};
template <typename T>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, avx2Long4<T> rhs)
{
    return _mm256_add_epi64(lhs._data, rhs._data);
}

template <
    typename T, typename U,
    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, U rhs)
{
    return _mm256_add_epi64(lhs._data, _mm256_set1_epi64x(rhs));
}
// 4 x 64-bit double lanes
struct avx2Double4
{
    static constexpr unsigned width     = 4;
    static constexpr unsigned alignment = 32;

    using scalarType      = double;
    using scalarIndexType = std::uint64_t;
    using vectorType      = __m256d;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;
    // ctors
    inline avx2Double4()                       = default;
    inline avx2Double4(const avx2Double4 &rhs) = default;
    inline avx2Double4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Double4(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }
    // store packed
    inline void store(scalarType *p) const
    {
        _mm256_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_pd(p, _data);
    }

    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_stream_pd(p, _data);
    }
    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_pd(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }
#if defined(__SSE2__) && defined(NEKTAR_ENABLE_SIMD_SSE2)
    // gather/scatter with sse2 32-bit integer indices
    template <typename T>
    inline void gather(scalarType const *p, const sse2Int4<T> &indices)
    {
        _data = _mm256_i32gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const sse2Int4<T> &indices) const
    {
        // no scatter intrinsic in AVX2: extract indices and store lane by lane
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm_extract_epi32(indices._data, 0)] = tmp[0];
        out[_mm_extract_epi32(indices._data, 1)] = tmp[1];
        out[_mm_extract_epi32(indices._data, 2)] = tmp[2];
        out[_mm_extract_epi32(indices._data, 3)] = tmp[3];
    }
#endif
    // gather/scatter with avx2 64-bit integer indices
    template <typename T>
    inline void gather(scalarType const *p, const avx2Long4<T> &indices)
    {
        _data = _mm256_i64gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Long4<T> &indices) const
    {
        // no scatter intrinsic in AVX2: extract indices and store lane by lane
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm256_extract_epi64(indices._data, 0)] = tmp[0];
        out[_mm256_extract_epi64(indices._data, 1)] = tmp[1];
        out[_mm256_extract_epi64(indices._data, 2)] = tmp[2];
        out[_mm256_extract_epi64(indices._data, 3)] = tmp[3];
    }
    // fused multiply-add: this += a * b
    inline void fma(const avx2Double4 &a, const avx2Double4 &b)
    {
        _data = _mm256_fmadd_pd(a._data, b._data, _data);
    }
    // subscript: convenient but expensive, should not be used in hot loops
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }
    // unary update operators
    inline void operator+=(avx2Double4 rhs)
    {
        _data = _mm256_add_pd(_data, rhs._data);
    }

    inline void operator-=(avx2Double4 rhs)
    {
        _data = _mm256_sub_pd(_data, rhs._data);
    }

    inline void operator*=(avx2Double4 rhs)
    {
        _data = _mm256_mul_pd(_data, rhs._data);
    }

    inline void operator/=(avx2Double4 rhs)
    {
        _data = _mm256_div_pd(_data, rhs._data);
    }
};
inline avx2Double4 operator+(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_add_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator-(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_sub_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator*(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_mul_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator/(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_div_pd(lhs._data, rhs._data);
}
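// Usage sketch (illustrative addition, not part of the original header): the
// member fma() and the free operators above compose into ordinary lane-wise
// expressions. The function name example_fma_blend is hypothetical.
inline avx2Double4 example_fma_blend(avx2Double4 a, avx2Double4 x,
                                     avx2Double4 y)
{
    avx2Double4 acc = y;
    acc.fma(a, x);        // acc = a * x + acc, one fused multiply-add per lane
    return acc / (x + y); // further lane-wise arithmetic via the operators
}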
inline avx2Double4 sqrt(avx2Double4 in)
{
    return _mm256_sqrt_pd(in._data);
}

inline avx2Double4 abs(avx2Double4 in)
{
    // clear the sign bit of every lane
    static const __m256d sign_mask = _mm256_set1_pd(-0.);
    return _mm256_andnot_pd(sign_mask, in._data);
}
inline avx2Double4 log(avx2Double4 in)
{
#if defined(TINYSIMD_HAS_SVML)
    return _mm256_log_pd(in._data);
#else
    // no AVX2 log intrinsic without SVML: scalar fallback through std::log
    alignas(avx2Double4::alignment) avx2Double4::scalarArray tmp;
    in.store(tmp);
    for (unsigned i = 0; i < avx2Double4::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx2Double4 ret;
    ret.load(tmp);
    return ret;
#endif
}
// load interleaved data, e.g. physical values at the quadrature points of
// several elements, into a block of SIMD vectors (one strided stream per lane)
inline void load_interleave(
    const double *in, size_t dataLen,
    std::vector<avx2Double4, allocator<avx2Double4>> &out)
{
    alignas(avx2Double4::alignment)
        size_t tmp[avx2Double4::width] = {0, dataLen, 2 * dataLen, 3 * dataLen};

    using index_t = avx2Long4<size_t>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    constexpr uint16_t unrl = 4;
    size_t nBlocks          = dataLen / unrl;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[unrl * i + 0].gather(in, index0);
        out[unrl * i + 1].gather(in, index1);
        out[unrl * i + 2].gather(in, index2);
        out[unrl * i + 3].gather(in, index3);
        index0 = index0 + unrl;
        index1 = index1 + unrl;
        index2 = index2 + unrl;
        index3 = index3 + unrl;
    }

    // spillover loop
    for (size_t i = unrl * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}
// store a block of SIMD vectors in interleaved fashion
inline void deinterleave_store(
    const std::vector<avx2Double4, allocator<avx2Double4>> &in, size_t dataLen,
    double *out)
{
    alignas(avx2Double4::alignment)
        size_t tmp[avx2Double4::width] = {0, dataLen, 2 * dataLen, 3 * dataLen};
    using index_t = avx2Long4<size_t>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
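// Usage sketch (illustrative addition, not part of the original header):
// round-trip avx2Double4::width interleaved streams through a block of SIMD
// vectors and square every entry lane-wise. The function name
// example_square_interleaved is hypothetical; in and out each hold
// avx2Double4::width * dataLen doubles.
inline void example_square_interleaved(const double *in, size_t dataLen,
                                       double *out)
{
    std::vector<avx2Double4, allocator<avx2Double4>> block(dataLen);
    load_interleave(in, dataLen, block);     // gather the strided streams
    for (size_t i = 0; i < dataLen; ++i)
    {
        block[i] = block[i] * block[i];      // 4 multiplications per instruction
    }
    deinterleave_store(block, dataLen, out); // scatter back to strided layout
}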
// 8 x 32-bit float lanes
struct avx2Float8
{
    static constexpr unsigned width     = 8;
    static constexpr unsigned alignment = 32;

    using scalarType      = float;
    using scalarIndexType = std::uint32_t;
    using vectorType      = __m256;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;
    // ctors
    inline avx2Float8()                      = default;
    inline avx2Float8(const avx2Float8 &rhs) = default;
    inline avx2Float8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Float8(const scalarType rhs)
    {
        _data = _mm256_set1_ps(rhs);
    }
    // store packed
    inline void store(scalarType *p) const
    {
        _mm256_store_ps(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_ps(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_ps(p, _data);
    }

    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_stream_ps(p, _data);
    }
    // load packed
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_ps(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_ps(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_ps(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_ps(rhs);
    }
    // gather/scatter with avx2 32-bit integer indices
    template <typename T>
    inline void gather(scalarType const *p, const avx2Int8<T> &indices)
    {
        _data = _mm256_i32gather_ps(p, indices._data, 4);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Int8<T> &indices) const
    {
        // no scatter intrinsic in AVX2: extract indices and store lane by lane
        alignas(alignment) scalarArray tmp;
        _mm256_store_ps(tmp, _data);

        out[_mm256_extract_epi32(indices._data, 0)] = tmp[0];
        out[_mm256_extract_epi32(indices._data, 1)] = tmp[1];
        out[_mm256_extract_epi32(indices._data, 2)] = tmp[2];
        out[_mm256_extract_epi32(indices._data, 3)] = tmp[3];
        out[_mm256_extract_epi32(indices._data, 4)] = tmp[4];
        out[_mm256_extract_epi32(indices._data, 5)] = tmp[5];
        out[_mm256_extract_epi32(indices._data, 6)] = tmp[6];
        out[_mm256_extract_epi32(indices._data, 7)] = tmp[7];
    }
    // fused multiply-add: this += a * b
    inline void fma(const avx2Float8 &a, const avx2Float8 &b)
    {
        _data = _mm256_fmadd_ps(a._data, b._data, _data);
    }
    // subscript: convenient but expensive, should not be used in hot loops
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
    // unary update operators
    inline void operator+=(avx2Float8 rhs)
    {
        _data = _mm256_add_ps(_data, rhs._data);
    }

    inline void operator-=(avx2Float8 rhs)
    {
        _data = _mm256_sub_ps(_data, rhs._data);
    }

    inline void operator*=(avx2Float8 rhs)
    {
        _data = _mm256_mul_ps(_data, rhs._data);
    }

    inline void operator/=(avx2Float8 rhs)
    {
        _data = _mm256_div_ps(_data, rhs._data);
    }
};
inline avx2Float8 operator+(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_add_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator-(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_sub_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator*(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_mul_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator/(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_div_ps(lhs._data, rhs._data);
}
inline avx2Float8 sqrt(avx2Float8 in)
{
    return _mm256_sqrt_ps(in._data);
}

inline avx2Float8 abs(avx2Float8 in)
{
    // clear the sign bit of every lane
    static const __m256 sign_mask = _mm256_set1_ps(-0.);
    return _mm256_andnot_ps(sign_mask, in._data);
}
inline avx2Float8 log(avx2Float8 in)
{
#if defined(TINYSIMD_HAS_SVML)
    return _mm256_log_ps(in._data);
#else
    // no AVX2 log intrinsic without SVML: scalar fallback through std::log
    alignas(avx2Float8::alignment) avx2Float8::scalarArray tmp;
    in.store(tmp);
    for (unsigned i = 0; i < avx2Float8::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx2Float8 ret;
    ret.load(tmp);
    return ret;
#endif
}
// load interleaved data into a block of SIMD vectors
inline void load_interleave(
    const float *in, std::uint32_t dataLen,
    std::vector<avx2Float8, allocator<avx2Float8>> &out)
{
    alignas(avx2Float8::alignment) avx2Float8::scalarIndexType tmp[8] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx2Int8<avx2Float8::scalarIndexType>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    size_t nBlocks = dataLen / 4;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[4 * i + 0].gather(in, index0);
        out[4 * i + 1].gather(in, index1);
        out[4 * i + 2].gather(in, index2);
        out[4 * i + 3].gather(in, index3);
        index0 = index0 + 4;
        index1 = index1 + 4;
        index2 = index2 + 4;
        index3 = index3 + 4;
    }

    // spillover loop
    for (size_t i = 4 * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}
// store a block of SIMD vectors in interleaved fashion
inline void deinterleave_store(
    const std::vector<avx2Float8, allocator<avx2Float8>> &in,
    std::uint32_t dataLen, float *out)
{
    alignas(avx2Float8::alignment) avx2Float8::scalarIndexType tmp[8] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
    using index_t = avx2Int8<avx2Float8::scalarIndexType>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
// mask types: integer vectors with boolean-valued lanes, set to true_v or
// false_v by the comparison operators below

struct avx2Mask4 : avx2Long4<std::uint64_t>
{
    // bring in ctors
    using avx2Long4::avx2Long4;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx2Mask4 operator>(avx2Double4 lhs, avx2Double4 rhs)
{
    return reinterpret_cast<__m256i>(
        _mm256_cmp_pd(lhs._data, rhs._data, _CMP_GT_OQ));
}

inline bool operator&&(avx2Mask4 lhs, bool rhs)
{
    // true only if every lane of the mask is set
    bool tmp =
        _mm256_testc_si256(lhs._data, _mm256_set1_epi64x(avx2Mask4::true_v));

    return tmp && rhs;
}
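// Usage sketch (illustrative addition, not part of the original header):
// operator> fills each mask lane with true_v/false_v and operator&& reduces
// the mask to "all lanes true". The function name example_all_greater is
// hypothetical.
inline bool example_all_greater(avx2Double4 lhs, avx2Double4 rhs)
{
    avx2Mask4 mask = lhs > rhs; // lane-wise comparison
    return mask && true;        // true only if every lane compared greater
}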
struct avx2Mask8 : avx2Int8<std::uint32_t>
{
    // bring in ctors
    using avx2Int8::avx2Int8;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx2Mask8 operator>(avx2Float8 lhs, avx2Float8 rhs)
{
    // rhs < lhs is equivalent to lhs > rhs
    return reinterpret_cast<__m256i>(
        _mm256_cmp_ps(rhs._data, lhs._data, _CMP_LT_OS));
}

inline bool operator&&(avx2Mask8 lhs, bool rhs)
{
    // build the all-ones reference from 32-bit lanes so that every lane of the
    // 8-wide mask is checked, not just the even ones
    bool tmp =
        _mm256_testc_si256(lhs._data, _mm256_set1_epi32(avx2Mask8::true_v));

    return tmp && rhs;
}

} // namespace tinysimd

#endif // defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)

#endif // NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H