#ifndef NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H
#define NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H

#if defined(__x86_64__)
#include <immintrin.h>
#if defined(__INTEL_COMPILER) && !defined(TINYSIMD_HAS_SVML)
#define TINYSIMD_HAS_SVML
#endif
#endif

// aligned allocator, sse2Int4 and the is_requiring_alignment/is_streaming
// traits used below
#include "allocator.hpp"
#include "sse2.hpp"
#include "traits.hpp"

#include <cmath>
#include <cstdint>
#include <type_traits>
#include <vector>

namespace tinysimd
{
namespace abi
{

// default: no AVX2 type available for this scalar/width combination
template <typename scalarType, int width = 0> struct avx2
{
    using type = void;
};

} // namespace abi
#if defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)

// forward declaration of concrete vector types
template <typename T> struct avx2Long4;
template <typename T> struct avx2Int8;
struct avx2Double4;
struct avx2Float8;
struct avx2Mask4;
struct avx2Mask8;
namespace abi
{

// mapping between abstract types and concrete floating point types
template <> struct avx2<double>
{
    using type = avx2Double4;
};
template <> struct avx2<float>
{
    using type = avx2Float8;
};
// generic integral types (use the widest lanes by default)
template <> struct avx2<std::int64_t>
{
    using type = avx2Long4<std::int64_t>;
};
template <> struct avx2<std::uint64_t>
{
    using type = avx2Long4<std::uint64_t>;
};
#if defined(__APPLE__)
template <> struct avx2<std::size_t>
{
    using type = avx2Long4<std::size_t>;
};
#endif
template <> struct avx2<std::int32_t>
{
    using type = avx2Int8<std::int32_t>;
};
template <> struct avx2<std::uint32_t>
{
    using type = avx2Int8<std::uint32_t>;
};
// integral types with an explicitly requested width
template <> struct avx2<std::int64_t, 4>
{
    using type = avx2Long4<std::int64_t>;
};
template <> struct avx2<std::uint64_t, 4>
{
    using type = avx2Long4<std::uint64_t>;
};
#if defined(__APPLE__)
template <> struct avx2<std::size_t, 4>
{
    using type = avx2Long4<std::size_t>;
};
#endif
template <> struct avx2<std::int32_t, 4>
{
    using type = sse2Int4<std::int32_t>;
};
template <> struct avx2<std::uint32_t, 4>
{
    using type = sse2Int4<std::uint32_t>;
};
template <> struct avx2<std::int32_t, 8>
{
    using type = avx2Int8<std::int32_t>;
};
template <> struct avx2<std::uint32_t, 8>
{
    using type = avx2Int8<std::uint32_t>;
};
// bool masks
template <> struct avx2<bool, 4>
{
    using type = avx2Mask4;
};
template <> struct avx2<bool, 8>
{
    using type = avx2Mask8;
};

} // namespace abi
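// Illustrative sketch (not part of the interface above): the abi::avx2 trait
// is what maps a requested scalar type and width onto one of the concrete
// wrappers defined below, e.g.
//
//     static_assert(
//         std::is_same<abi::avx2<double>::type, avx2Double4>::value, "");
//     static_assert(std::is_same<abi::avx2<std::uint32_t, 8>::type,
//                                avx2Int8<std::uint32_t>>::value, "");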
// concrete types
template <typename T> struct avx2Int8
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 4,
                  "4 bytes Integral required.");

    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx2Int8()                    = default;
    inline avx2Int8(const avx2Int8 &rhs) = default;
    inline avx2Int8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Int8(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }
    explicit inline avx2Int8(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }

    // copy assignment
    inline avx2Int8 &operator=(const avx2Int8 &) = default;

    // store packed, aligned
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // store packed, aligned and non-streaming
    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // store packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // load packed, aligned
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    // load packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi32(rhs);
    }

    // subscript operators are convenient but expensive;
    // they should not be used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};
template <typename T>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, avx2Int8<T> rhs)
{
    return _mm256_add_epi32(lhs._data, rhs._data);
}

template <
    typename T, typename U,
    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Int8<T> operator+(avx2Int8<T> lhs, U rhs)
{
    return _mm256_add_epi32(lhs._data, _mm256_set1_epi32(rhs));
}
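// Illustrative sketch: avx2Int8 is mainly used as an 8-lane index/offset
// vector for avx2Float8 below; the index values here are example data only.
//
//     alignas(avx2Int8<std::uint32_t>::alignment) std::uint32_t idx[8] = {
//         0, 1, 2, 3, 4, 5, 6, 7};
//     avx2Int8<std::uint32_t> offsets(idx); // aligned load via the array ctor
//     offsets = offsets + 16;               // shift every index by 16
//     offsets.store(idx);                   // write the shifted indices back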
template <typename T> struct avx2Long4
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 8,
                  "8 bytes Integral required.");

    static constexpr unsigned int width     = 4;
    static constexpr unsigned int alignment = 32;

    using scalarType  = T;
    using vectorType  = __m256i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx2Long4()                     = default;
    inline avx2Long4(const avx2Long4 &rhs) = default;
    inline avx2Long4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Long4(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }
    explicit inline avx2Long4(scalarArray &rhs)
    {
        _data = _mm256_load_si256(reinterpret_cast<vectorType *>(rhs));
    }

    // copy assignment
    inline avx2Long4 &operator=(const avx2Long4 &) = default;

    // store packed, aligned
    inline void store(scalarType *p) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // store packed, aligned and non-streaming
    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // store packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_si256(reinterpret_cast<vectorType *>(p), _data);
    }

    // load packed, aligned
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_si256(reinterpret_cast<const vectorType *>(p));
    }

    // load packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_si256(reinterpret_cast<const vectorType *>(p));
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_epi64x(rhs);
    }

    // subscript operators are convenient but expensive;
    // they should not be used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};
template <typename T>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, avx2Long4<T> rhs)
{
    return _mm256_add_epi64(lhs._data, rhs._data);
}

template <
    typename T, typename U,
    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx2Long4<T> operator+(avx2Long4<T> lhs, U rhs)
{
    return _mm256_add_epi64(lhs._data, _mm256_set1_epi64x(rhs));
}
struct avx2Double4
{
    static constexpr unsigned width     = 4;
    static constexpr unsigned alignment = 32;

    using scalarType      = double;
    using scalarIndexType = std::uint64_t;
    using vectorType      = __m256d;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx2Double4()                       = default;
    inline avx2Double4(const avx2Double4 &rhs) = default;
    inline avx2Double4(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Double4(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }

    // copy assignment
    inline avx2Double4 &operator=(const avx2Double4 &) = default;

    // store packed, aligned
    inline void store(scalarType *p) const
    {
        _mm256_store_pd(p, _data);
    }

    // store packed, aligned and non-streaming
    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_pd(p, _data);
    }

    // store packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_pd(p, _data);
    }

    // store packed, streaming
    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_stream_pd(p, _data);
    }

    // load packed, aligned
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_pd(p);
    }

    // load packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_pd(p);
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_pd(rhs);
    }

#if defined(__SSE2__) && defined(NEKTAR_ENABLE_SIMD_SSE2)
    // gather/scatter with 32-bit (sse2) indices
    template <typename T>
    inline void gather(scalarType const *p, const sse2Int4<T> &indices)
    {
        _data = _mm256_i32gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const sse2Int4<T> &indices) const
    {
        // no scatter intrinsics for AVX2
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm_extract_epi32(indices._data, 0)] = tmp[0];
        out[_mm_extract_epi32(indices._data, 1)] = tmp[1];
        out[_mm_extract_epi32(indices._data, 2)] = tmp[2];
        out[_mm_extract_epi32(indices._data, 3)] = tmp[3];
    }
#endif

    // gather/scatter with 64-bit (avx2) indices
    template <typename T>
    inline void gather(scalarType const *p, const avx2Long4<T> &indices)
    {
        _data = _mm256_i64gather_pd(p, indices._data, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Long4<T> &indices) const
    {
        // no scatter intrinsics for AVX2
        alignas(alignment) scalarArray tmp;
        _mm256_store_pd(tmp, _data);

        out[_mm256_extract_epi64(indices._data, 0)] = tmp[0];
        out[_mm256_extract_epi64(indices._data, 1)] = tmp[1];
        out[_mm256_extract_epi64(indices._data, 2)] = tmp[2];
        out[_mm256_extract_epi64(indices._data, 3)] = tmp[3];
    }
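    // Illustrative sketch: the gather/scatter members above read or write the
    // four lanes from arbitrary positions of an array, with the positions held
    // in an avx2Long4 index vector (example values only).
    //
    //     double field[100] = {};
    //     alignas(32) std::uint64_t idx[4] = {0, 25, 50, 75};
    //     avx2Long4<std::uint64_t> indices(idx);
    //     avx2Double4 v;
    //     v.gather(field, indices);  // v[i] = field[idx[i]]
    //     v.scatter(field, indices); // field[idx[i]] = v[i]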
    // fused multiply-add: this = a * b + this
    inline void fma(const avx2Double4 &a, const avx2Double4 &b)
    {
        _data = _mm256_fmadd_pd(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive;
    // they should not be used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }

    // unary arithmetic operators
    inline void operator+=(avx2Double4 rhs)
    {
        _data = _mm256_add_pd(_data, rhs._data);
    }

    inline void operator-=(avx2Double4 rhs)
    {
        _data = _mm256_sub_pd(_data, rhs._data);
    }

    inline void operator*=(avx2Double4 rhs)
    {
        _data = _mm256_mul_pd(_data, rhs._data);
    }

    inline void operator/=(avx2Double4 rhs)
    {
        _data = _mm256_div_pd(_data, rhs._data);
    }
};
inline avx2Double4 operator+(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_add_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator-(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_sub_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator*(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_mul_pd(lhs._data, rhs._data);
}

inline avx2Double4 operator/(avx2Double4 lhs, avx2Double4 rhs)
{
    return _mm256_div_pd(lhs._data, rhs._data);
}
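// Illustrative sketch: the free operators above allow scalar-looking
// arithmetic on four doubles at a time, and fma() fuses a multiply-add into
// the accumulator (example values only).
//
//     avx2Double4 a(2.0), b(3.0), acc(1.0);
//     avx2Double4 c = (a + b) * a - b / a; // element-wise on all four lanes
//     acc.fma(a, b);                       // acc = a * b + acc, i.e. 7.0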
inline avx2Double4 sqrt(avx2Double4 in)
{
    return _mm256_sqrt_pd(in._data);
}

inline avx2Double4 abs(avx2Double4 in)
{
    // no avx2 abs intrinsic for floating point: clear the sign bit instead
    static const __m256d sign_mask = _mm256_set1_pd(-0.);
    return _mm256_andnot_pd(sign_mask, in._data);
}

inline avx2Double4 log(avx2Double4 in)
{
#if defined(TINYSIMD_HAS_SVML)
    return _mm256_log_pd(in._data);
#else
    // no avx2 log intrinsic: fall back to scalar std::log per lane
    alignas(avx2Double4::alignment) avx2Double4::scalarArray tmp;
    in.store(tmp);
    for (unsigned i = 0; i < avx2Double4::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx2Double4 ret;
    ret.load(tmp);
    return ret;
#endif
}

inline void load_interleave(
    const double *in, std::uint32_t dataLen,
    std::vector<avx2Double4, allocator<avx2Double4>> &out)
{
    alignas(avx2Double4::alignment)
        size_t tmp[avx2Double4::width] = {0, dataLen, 2 * dataLen, 3 * dataLen};

    using index_t = avx2Long4<size_t>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    constexpr uint16_t unrl = 4;
    size_t nBlocks          = dataLen / unrl;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[unrl * i + 0].gather(in, index0);
        out[unrl * i + 1].gather(in, index1);
        out[unrl * i + 2].gather(in, index2);
        out[unrl * i + 3].gather(in, index3);
        index0 = index0 + unrl;
        index1 = index1 + unrl;
        index2 = index2 + unrl;
        index3 = index3 + unrl;
    }

    // spillover loop
    for (size_t i = unrl * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}

inline void deinterleave_store(
    const std::vector<avx2Double4, allocator<avx2Double4>> &in,
    std::uint32_t dataLen, double *out)
{
    alignas(avx2Double4::alignment)
        size_t tmp[avx2Double4::width] = {0, dataLen, 2 * dataLen, 3 * dataLen};
    using index_t = avx2Long4<size_t>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
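// Illustrative sketch: load_interleave treats `in` as width (= 4) consecutive
// blocks of dataLen doubles and packs element i of every block into out[i];
// deinterleave_store is the inverse. Sizes and values are example data only.
//
//     std::uint32_t dataLen = 16;
//     std::vector<double> blocks(4 * dataLen, 1.0);
//     std::vector<avx2Double4, allocator<avx2Double4>> packed(dataLen);
//     load_interleave(blocks.data(), dataLen, packed);
//     deinterleave_store(packed, dataLen, blocks.data());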
struct avx2Float8
{
    static constexpr unsigned width     = 8;
    static constexpr unsigned alignment = 32;

    using scalarType      = float;
    using scalarIndexType = std::uint32_t;
    using vectorType      = __m256;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx2Float8()                      = default;
    inline avx2Float8(const avx2Float8 &rhs) = default;
    inline avx2Float8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx2Float8(const scalarType rhs)
    {
        _data = _mm256_set1_ps(rhs);
    }

    // copy assignment
    inline avx2Float8 &operator=(const avx2Float8 &) = default;

    // store packed, aligned
    inline void store(scalarType *p) const
    {
        _mm256_store_ps(p, _data);
    }

    // store packed, aligned and non-streaming
    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_store_ps(p, _data);
    }

    // store packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_storeu_ps(p, _data);
    }

    // store packed, streaming
    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm256_stream_ps(p, _data);
    }

    // load packed, aligned
    inline void load(const scalarType *p)
    {
        _data = _mm256_load_ps(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_load_ps(p);
    }

    // load packed, unaligned
    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm256_loadu_ps(p);
    }

    // broadcast a scalar to all lanes
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm256_set1_ps(rhs);
    }

    // gather/scatter with 32-bit (avx2) indices
    template <typename T>
    inline void gather(scalarType const *p, const avx2Int8<T> &indices)
    {
        _data = _mm256_i32gather_ps(p, indices._data, 4);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Int8<T> &indices) const
    {
        // no scatter intrinsics for AVX2
        alignas(alignment) scalarArray tmp;
        _mm256_store_ps(tmp, _data);

        out[_mm256_extract_epi32(indices._data, 0)] = tmp[0];
        out[_mm256_extract_epi32(indices._data, 1)] = tmp[1];
        out[_mm256_extract_epi32(indices._data, 2)] = tmp[2];
        out[_mm256_extract_epi32(indices._data, 3)] = tmp[3];
        out[_mm256_extract_epi32(indices._data, 4)] = tmp[4];
        out[_mm256_extract_epi32(indices._data, 5)] = tmp[5];
        out[_mm256_extract_epi32(indices._data, 6)] = tmp[6];
        out[_mm256_extract_epi32(indices._data, 7)] = tmp[7];
    }

    // fused multiply-add: this = a * b + this
    inline void fma(const avx2Float8 &a, const avx2Float8 &b)
    {
        _data = _mm256_fmadd_ps(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive;
    // they should not be used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }

    // unary arithmetic operators
    inline void operator+=(avx2Float8 rhs)
    {
        _data = _mm256_add_ps(_data, rhs._data);
    }

    inline void operator-=(avx2Float8 rhs)
    {
        _data = _mm256_sub_ps(_data, rhs._data);
    }

    inline void operator*=(avx2Float8 rhs)
    {
        _data = _mm256_mul_ps(_data, rhs._data);
    }

    inline void operator/=(avx2Float8 rhs)
    {
        _data = _mm256_div_ps(_data, rhs._data);
    }
};
inline avx2Float8 operator+(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_add_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator-(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_sub_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator*(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_mul_ps(lhs._data, rhs._data);
}

inline avx2Float8 operator/(avx2Float8 lhs, avx2Float8 rhs)
{
    return _mm256_div_ps(lhs._data, rhs._data);
}
inline avx2Float8 sqrt(avx2Float8 in)
{
    return _mm256_sqrt_ps(in._data);
}

inline avx2Float8 abs(avx2Float8 in)
{
    // no avx2 abs intrinsic for floating point: clear the sign bit instead
    static const __m256 sign_mask = _mm256_set1_ps(-0.);
    return _mm256_andnot_ps(sign_mask, in._data);
}

inline avx2Float8 log(avx2Float8 in)
{
    // no avx2 log intrinsic: fall back to scalar std::log per lane
    alignas(avx2Float8::alignment) avx2Float8::scalarArray tmp;
    in.store(tmp);
    for (unsigned i = 0; i < avx2Float8::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx2Float8 ret;
    ret.load(tmp);
    return ret;
}

inline void load_interleave(
    const float *in, std::uint32_t dataLen,
    std::vector<avx2Float8, allocator<avx2Float8>> &out)
{
    alignas(avx2Float8::alignment) avx2Float8::scalarIndexType tmp[8] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx2Int8<avx2Float8::scalarIndexType>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    size_t nBlocks = dataLen / 4;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[4 * i + 0].gather(in, index0);
        out[4 * i + 1].gather(in, index1);
        out[4 * i + 2].gather(in, index2);
        out[4 * i + 3].gather(in, index3);
        index0 = index0 + 4;
        index1 = index1 + 4;
        index2 = index2 + 4;
        index3 = index3 + 4;
    }

    // spillover loop
    for (size_t i = 4 * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}

inline void deinterleave_store(
    const std::vector<avx2Float8, allocator<avx2Float8>> &in,
    std::uint32_t dataLen, float *out)
{
    alignas(avx2Float8::alignment) avx2Float8::scalarIndexType tmp[8] = {
        0,           dataLen,     2 * dataLen, 3 * dataLen,
        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
    using index_t = avx2Int8<avx2Float8::scalarIndexType>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
// mask types: 4- and 8-lane boolean vectors with all-ones (true) or
// all-zeros (false) lanes, built on top of the integer types above
struct avx2Mask4 : avx2Long4<std::uint64_t>
{
    // bring in ctors
    using avx2Long4::avx2Long4;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx2Mask4 operator>(avx2Double4 lhs, avx2Double4 rhs)
{
    return reinterpret_cast<__m256i>(
        _mm256_cmp_pd(lhs._data, rhs._data, _CMP_GT_OQ));
}

inline bool operator&&(avx2Mask4 lhs, bool rhs)
{
    bool tmp =
        _mm256_testc_si256(lhs._data, _mm256_set1_epi64x(avx2Mask4::true_v));
    return tmp && rhs;
}

struct avx2Mask8 : avx2Int8<std::uint32_t>
{
    // bring in ctors
    using avx2Int8::avx2Int8;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx2Mask8 operator>(avx2Float8 lhs, avx2Float8 rhs)
{
    // lhs > rhs  <=>  rhs < lhs
    return reinterpret_cast<__m256i>(
        _mm256_cmp_ps(rhs._data, lhs._data, _CMP_LT_OS));
}

inline bool operator&&(avx2Mask8 lhs, bool rhs)
{
    bool tmp =
        _mm256_testc_si256(lhs._data, _mm256_set1_epi32(avx2Mask8::true_v));
    return tmp && rhs;
}
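// Illustrative sketch: operator> returns a full-width lane mask and
// operator&& collapses it, so an "all lanes greater" test reads like scalar
// code (example values only).
//
//     avx2Double4 x(2.0), y(1.0);
//     avx2Mask4 mask = x > y; // every lane true here
//     if (mask && true)
//     {
//         // taken only when all four lanes compare greater
//     }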
#endif // defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)

} // namespace tinysimd

#endif // NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX2_H