#ifndef NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX512_H
#define NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX512_H

#if defined(__x86_64__)
#include <immintrin.h>
#if defined(__INTEL_COMPILER) && !defined(TINYSIMD_HAS_SVML)
#define TINYSIMD_HAS_SVML
#endif
#endif // defined(__x86_64__)

// aligned allocator, avx2Int8 and the alignment/streaming traits used below
#include "allocator.hpp"
#include "avx2.hpp"
#include "traits.hpp"
#include <cmath>
#include <cstdint>
#include <vector>

// maps a scalar type (and an optional vector width) onto the concrete
// AVX-512 vector type implemented below
template <typename scalarType, int width = 0> struct avx512
{
    using type = void;
};
#if defined(__AVX512F__) && defined(NEKTAR_ENABLE_SIMD_AVX512)

// forward declaration of concrete types
template <typename T> struct avx512Long8;
template <typename T> struct avx512Int16;
struct avx512Double8;
struct avx512Float16;
struct avx512Mask8;
struct avx512Mask16;
// default-width mapping between abstract and concrete types
template <> struct avx512<double>
{
    using type = avx512Double8;
};
template <> struct avx512<float>
{
    using type = avx512Float16;
};
template <> struct avx512<std::int64_t>
{
    using type = avx512Long8<std::int64_t>;
};
template <> struct avx512<std::uint64_t>
{
    using type = avx512Long8<std::uint64_t>;
};
#if defined(__APPLE__)
template <> struct avx512<std::size_t>
{
    using type = avx512Long8<std::size_t>;
};
#endif
template <> struct avx512<std::int32_t>
{
    using type = avx512Int16<std::int32_t>;
};
template <> struct avx512<std::uint32_t>
{
    using type = avx512Int16<std::uint32_t>;
};

// explicit-width mappings
template <> struct avx512<std::int64_t, 8>
{
    using type = avx512Long8<std::int64_t>;
};
template <> struct avx512<std::uint64_t, 8>
{
    using type = avx512Long8<std::uint64_t>;
};
#if defined(__APPLE__)
template <> struct avx512<std::size_t, 8>
{
    using type = avx512Long8<std::size_t>;
};
#endif
template <> struct avx512<std::int32_t, 8>
{
    using type = avx2Int8<std::int32_t>;
};
template <> struct avx512<std::uint32_t, 8>
{
    using type = avx2Int8<std::uint32_t>;
};
template <> struct avx512<std::int32_t, 16>
{
    using type = avx512Int16<std::int32_t>;
};
template <> struct avx512<std::uint32_t, 16>
{
    using type = avx512Int16<std::uint32_t>;
};

// mask types
template <> struct avx512<bool, 8>
{
    using type = avx512Mask8;
};
template <> struct avx512<bool, 16>
{
    using type = avx512Mask16;
};
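
// Illustrative sketch (not part of the library): client code is expected to
// pick a concrete vector type through the avx512 trait above; the exact
// enclosing namespace is assumed and omitted here.
//
//     using vec_t = avx512<double>::type;        // avx512Double8, 8 lanes
//     using idx_t = avx512<std::uint64_t>::type; // avx512Long8<std::uint64_t>
//     static_assert(vec_t::width == 8 && vec_t::alignment == 64, "");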
// concrete types
template <typename T> struct avx512Int16
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 4,
                  "4 bytes Integral required.");

    static constexpr unsigned int width     = 16;
    static constexpr unsigned int alignment = 64;

    using scalarType  = T;
    using vectorType  = __m512i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Int16()                       = default;
    inline avx512Int16(const avx512Int16 &rhs) = default;
    inline avx512Int16(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Int16(const scalarType rhs)
    {
        _data = _mm512_set1_epi32(rhs);
    }
    explicit inline avx512Int16(scalarArray &rhs)
    {
        _data = _mm512_load_epi32(rhs);
    }

    // copy assignment
    inline avx512Int16 &operator=(const avx512Int16 &) = default;

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_epi32(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_epi32(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_epi32(p, _data);
    }

    // load
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_epi32(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_epi32(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        // bitwise unaligned load; _mm512_loadu_epi32 is not implemented by
        // all supported compilers, _mm512_loadu_si512 is equivalent here
        _data = _mm512_loadu_si512(p);
    }

    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_epi32(rhs);
    }

    // subscript operators are convenient but expensive and should not be
    // used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};

template <typename T>
inline avx512Int16<T> operator+(avx512Int16<T> lhs, avx512Int16<T> rhs)
{
    return _mm512_add_epi32(lhs._data, rhs._data);
}

template <typename T, typename U,
          typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx512Int16<T> operator+(avx512Int16<T> lhs, U rhs)
{
    return _mm512_add_epi32(lhs._data, _mm512_set1_epi32(rhs));
}
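
// Illustrative sketch (not part of the library): broadcast, lane-wise scalar
// addition and an aligned round trip with avx512Int16. The buffer name is
// arbitrary; is_aligned is the alignment tag consumed by the flagged
// overloads.
//
//     alignas(avx512Int16<int>::alignment) avx512Int16<int>::scalarArray buf = {};
//     avx512Int16<int> v(1);    // broadcast 1 into all 16 lanes
//     v = v + 3;                // lane-wise add of a scalar
//     v.store(buf, is_aligned); // aligned store back into buf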
template <typename T> struct avx512Long8
{
    static_assert(std::is_integral<T>::value && sizeof(T) == 8,
                  "8 bytes Integral required.");

    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 64;

    using scalarType  = T;
    using vectorType  = __m512i;
    using scalarArray = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Long8()                       = default;
    inline avx512Long8(const avx512Long8 &rhs) = default;
    inline avx512Long8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Long8(const scalarType rhs)
    {
        _data = _mm512_set1_epi64(rhs);
    }
    explicit inline avx512Long8(scalarArray &rhs)
    {
        _data = _mm512_load_epi64(rhs);
    }

    // copy assignment
    inline avx512Long8 &operator=(const avx512Long8 &) = default;

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_epi64(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_epi64(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_epi64(p, _data);
    }

    // load
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_epi64(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_epi64(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        // bitwise unaligned load; _mm512_loadu_epi64 is not implemented by
        // all supported compilers, _mm512_loadu_si512 is equivalent here
        _data = _mm512_loadu_si512(p);
    }

    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_epi64(rhs);
    }

    // subscript operators are convenient but expensive and should not be
    // used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }
};

template <typename T>
inline avx512Long8<T> operator+(avx512Long8<T> lhs, avx512Long8<T> rhs)
{
    return _mm512_add_epi64(lhs._data, rhs._data);
}

template <typename T, typename U,
          typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
inline avx512Long8<T> operator+(avx512Long8<T> lhs, U rhs)
{
    return _mm512_add_epi64(lhs._data, _mm512_set1_epi64(rhs));
}
struct avx512Double8
{
    static constexpr unsigned int width     = 8;
    static constexpr unsigned int alignment = 64;

    using scalarType      = double;
    using scalarIndexType = std::uint64_t;
    using vectorType      = __m512d;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Double8()                         = default;
    inline avx512Double8(const avx512Double8 &rhs) = default;
    inline avx512Double8(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Double8(const scalarType rhs)
    {
        _data = _mm512_set1_pd(rhs);
    }

    // copy assignment
    inline avx512Double8 &operator=(const avx512Double8 &) = default;

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_pd(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_pd(p, _data);
    }

    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_stream_pd(p, _data);
    }

    // load
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_pd(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_loadu_pd(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_pd(rhs);
    }

    // gather/scatter with 32-bit indices (avx2Int8)
    template <typename T>
    inline void gather(scalarType const *p, const avx2Int8<T> &indices)
    {
        _data = _mm512_i32gather_pd(indices._data, p, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx2Int8<T> &indices) const
    {
        _mm512_i32scatter_pd(out, indices._data, _data, 8);
    }

    // gather/scatter with 64-bit indices (avx512Long8)
    template <typename T>
    inline void gather(scalarType const *p, const avx512Long8<T> &indices)
    {
        _data = _mm512_i64gather_pd(indices._data, p, 8);
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx512Long8<T> &indices) const
    {
        _mm512_i64scatter_pd(out, indices._data, _data, 8);
    }

    // fused multiply-add: this += a * b
    inline void fma(const avx512Double8 &a, const avx512Double8 &b)
    {
        _data = _mm512_fmadd_pd(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive and should not be
    // used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }

    // unary operators
    inline void operator+=(avx512Double8 rhs)
    {
        _data = _mm512_add_pd(_data, rhs._data);
    }

    inline void operator-=(avx512Double8 rhs)
    {
        _data = _mm512_sub_pd(_data, rhs._data);
    }

    inline void operator*=(avx512Double8 rhs)
    {
        _data = _mm512_mul_pd(_data, rhs._data);
    }

    inline void operator/=(avx512Double8 rhs)
    {
        _data = _mm512_div_pd(_data, rhs._data);
    }
};
inline avx512Double8 operator+(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_add_pd(lhs._data, rhs._data);
}
inline avx512Double8 operator-(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_sub_pd(lhs._data, rhs._data);
}
inline avx512Double8 operator*(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_mul_pd(lhs._data, rhs._data);
}
inline avx512Double8 operator/(avx512Double8 lhs, avx512Double8 rhs)
{
    return _mm512_div_pd(lhs._data, rhs._data);
}

inline avx512Double8 sqrt(avx512Double8 in)
{
    return _mm512_sqrt_pd(in._data);
}

inline avx512Double8 abs(avx512Double8 in)
{
    return _mm512_abs_pd(in._data);
}

inline avx512Double8 log(avx512Double8 in)
{
#if defined(TINYSIMD_HAS_SVML)
    return _mm512_log_pd(in._data);
#else
    // no AVX-512 log intrinsic without SVML: fall back to scalar std::log
    alignas(avx512Double8::alignment) avx512Double8::scalarArray tmp;
    in.store(tmp);
    for (size_t i = 0; i < avx512Double8::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx512Double8 ret;
    ret.load(tmp);
    return ret;
#endif
}
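
// Illustrative sketch (not part of the library): basic use of avx512Double8
// with the member and free operators above. Buffer names are arbitrary;
// is_aligned is the alignment tag used by the flagged overloads.
//
//     alignas(avx512Double8::alignment) double buf[avx512Double8::width] = {};
//     avx512Double8 x, y(2.0);  // y broadcasts 2.0 into all 8 lanes
//     x.load(buf, is_aligned);  // aligned load of 8 doubles
//     x.fma(y, y);              // x += y * y
//     x = sqrt(x + y) / y;      // lane-wise math via the free functions
//     x.store(buf, is_aligned);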
// load interleaved data and transpose it into dataLen vector registers
inline void load_interleave(
    const double *in, std::uint32_t dataLen,
    std::vector<avx512Double8, allocator<avx512Double8>> &out)
{
    alignas(avx512Double8::alignment)
        avx512Double8::scalarIndexType tmp[avx512Double8::width] = {
            0, dataLen, 2 * dataLen, 3 * dataLen,
            4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx512Long8<avx512Double8::scalarIndexType>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    constexpr uint16_t unrl = 4;
    size_t nBlocks          = dataLen / unrl;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[unrl * i + 0].gather(in, index0);
        out[unrl * i + 1].gather(in, index1);
        out[unrl * i + 2].gather(in, index2);
        out[unrl * i + 3].gather(in, index3);
        index0 = index0 + unrl;
        index1 = index1 + unrl;
        index2 = index2 + unrl;
        index3 = index3 + unrl;
    }

    // spillover loop
    for (size_t i = unrl * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}

// transpose dataLen vector registers back and store them contiguously
inline void deinterleave_store(
    const std::vector<avx512Double8, allocator<avx512Double8>> &in,
    std::uint32_t dataLen, double *out)
{
    alignas(avx512Double8::alignment)
        avx512Double8::scalarIndexType tmp[avx512Double8::width] = {
            0, dataLen, 2 * dataLen, 3 * dataLen,
            4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};

    using index_t = avx512Long8<avx512Double8::scalarIndexType>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
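
// Illustrative sketch (not part of the library): round trip through the
// interleave helpers above. Sizes and names are arbitrary; the source holds
// width * dataLen interleaved scalars.
//
//     std::uint32_t dataLen = 32;
//     std::vector<double> src(avx512Double8::width * dataLen), dst(src.size());
//     std::vector<avx512Double8, allocator<avx512Double8>> work(dataLen);
//     load_interleave(src.data(), dataLen, work);     // gather into registers
//     deinterleave_store(work, dataLen, dst.data());  // and scatter back out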
struct avx512Float16
{
    static constexpr unsigned int width     = 16;
    static constexpr unsigned int alignment = 64;

    using scalarType      = float;
    using scalarIndexType = std::uint32_t;
    using vectorType      = __m512;
    using scalarArray     = scalarType[width];

    // storage
    vectorType _data;

    // ctors
    inline avx512Float16()                         = default;
    inline avx512Float16(const avx512Float16 &rhs) = default;
    inline avx512Float16(const vectorType &rhs) : _data(rhs)
    {
    }
    inline avx512Float16(const scalarType rhs)
    {
        _data = _mm512_set1_ps(rhs);
    }

    // copy assignment
    inline avx512Float16 &operator=(const avx512Float16 &) = default;

    // store
    inline void store(scalarType *p) const
    {
        _mm512_store_ps(p, _data);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value &&
                                          !is_streaming<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_store_ps(p, _data);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_storeu_ps(p, _data);
    }

    template <class flag, typename std::enable_if<is_streaming<flag>::value,
                                                  bool>::type = 0>
    inline void store(scalarType *p, flag) const
    {
        _mm512_stream_ps(p, _data);
    }

    // load
    inline void load(const scalarType *p)
    {
        _data = _mm512_load_ps(p);
    }

    template <class flag,
              typename std::enable_if<is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_load_ps(p);
    }

    template <class flag,
              typename std::enable_if<!is_requiring_alignment<flag>::value,
                                      bool>::type = 0>
    inline void load(const scalarType *p, flag)
    {
        _data = _mm512_loadu_ps(p);
    }

    // broadcast
    inline void broadcast(const scalarType rhs)
    {
        _data = _mm512_set1_ps(rhs);
    }

    // gather/scatter with 32-bit indices (avx512Int16)
    template <typename T>
    inline void gather(scalarType const *p, const avx512Int16<T> &indices)
    {
        _data = _mm512_i32gather_ps(indices._data, p, sizeof(scalarType));
    }

    template <typename T>
    inline void scatter(scalarType *out, const avx512Int16<T> &indices) const
    {
        _mm512_i32scatter_ps(out, indices._data, _data, sizeof(scalarType));
    }

    // fused multiply-add: this += a * b
    inline void fma(const avx512Float16 &a, const avx512Float16 &b)
    {
        _data = _mm512_fmadd_ps(a._data, b._data, _data);
    }

    // subscript operators are convenient but expensive and should not be
    // used in optimised kernels
    inline scalarType operator[](size_t i) const
    {
        alignas(alignment) scalarArray tmp;
        store(tmp, is_aligned);
        return tmp[i];
    }

    inline scalarType &operator[](size_t i)
    {
        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
        return tmp[i];
    }

    // unary operators
    inline void operator+=(avx512Float16 rhs)
    {
        _data = _mm512_add_ps(_data, rhs._data);
    }

    inline void operator-=(avx512Float16 rhs)
    {
        _data = _mm512_sub_ps(_data, rhs._data);
    }

    inline void operator*=(avx512Float16 rhs)
    {
        _data = _mm512_mul_ps(_data, rhs._data);
    }

    inline void operator/=(avx512Float16 rhs)
    {
        _data = _mm512_div_ps(_data, rhs._data);
    }
};
inline avx512Float16 operator+(avx512Float16 lhs, avx512Float16 rhs)
{
    return _mm512_add_ps(lhs._data, rhs._data);
}
inline avx512Float16 operator-(avx512Float16 lhs, avx512Float16 rhs)
{
    return _mm512_sub_ps(lhs._data, rhs._data);
}
inline avx512Float16 operator*(avx512Float16 lhs, avx512Float16 rhs)
{
    return _mm512_mul_ps(lhs._data, rhs._data);
}
inline avx512Float16 operator/(avx512Float16 lhs, avx512Float16 rhs)
{
    return _mm512_div_ps(lhs._data, rhs._data);
}

inline avx512Float16 sqrt(avx512Float16 in)
{
    return _mm512_sqrt_ps(in._data);
}

inline avx512Float16 abs(avx512Float16 in)
{
    return _mm512_abs_ps(in._data);
}

inline avx512Float16 log(avx512Float16 in)
{
#if defined(TINYSIMD_HAS_SVML)
    return _mm512_log_ps(in._data);
#else
    // no AVX-512 log intrinsic without SVML: fall back to scalar std::log
    alignas(avx512Float16::alignment) avx512Float16::scalarArray tmp;
    in.store(tmp);
    for (size_t i = 0; i < avx512Float16::width; ++i)
    {
        tmp[i] = std::log(tmp[i]);
    }
    avx512Float16 ret;
    ret.load(tmp);
    return ret;
#endif
}
// load interleaved data and transpose it into dataLen vector registers
inline void load_interleave(
    const float *in, std::uint32_t dataLen,
    std::vector<avx512Float16, allocator<avx512Float16>> &out)
{
    // all 16 lanes need an offset; a partial initialiser list would leave
    // lanes 8-15 gathering from in[0]
    alignas(avx512Float16::alignment)
        avx512Float16::scalarIndexType tmp[avx512Float16::width] = {
            0,            dataLen,      2 * dataLen,  3 * dataLen,
            4 * dataLen,  5 * dataLen,  6 * dataLen,  7 * dataLen,
            8 * dataLen,  9 * dataLen,  10 * dataLen, 11 * dataLen,
            12 * dataLen, 13 * dataLen, 14 * dataLen, 15 * dataLen};

    using index_t = avx512Int16<avx512Float16::scalarIndexType>;
    index_t index0(tmp);
    index_t index1 = index0 + 1;
    index_t index2 = index0 + 2;
    index_t index3 = index0 + 3;

    // 4x unrolled loop
    constexpr uint16_t unrl = 4;
    size_t nBlocks          = dataLen / unrl;
    for (size_t i = 0; i < nBlocks; ++i)
    {
        out[unrl * i + 0].gather(in, index0);
        out[unrl * i + 1].gather(in, index1);
        out[unrl * i + 2].gather(in, index2);
        out[unrl * i + 3].gather(in, index3);
        index0 = index0 + unrl;
        index1 = index1 + unrl;
        index2 = index2 + unrl;
        index3 = index3 + unrl;
    }

    // spillover loop
    for (size_t i = unrl * nBlocks; i < dataLen; ++i)
    {
        out[i].gather(in, index0);
        index0 = index0 + 1;
    }
}

// transpose dataLen vector registers back and store them contiguously
inline void deinterleave_store(
    const std::vector<avx512Float16, allocator<avx512Float16>> &in,
    std::uint32_t dataLen, float *out)
{
    alignas(avx512Float16::alignment)
        avx512Float16::scalarIndexType tmp[avx512Float16::width] = {
            0,            dataLen,      2 * dataLen,  3 * dataLen,
            4 * dataLen,  5 * dataLen,  6 * dataLen,  7 * dataLen,
            8 * dataLen,  9 * dataLen,  10 * dataLen, 11 * dataLen,
            12 * dataLen, 13 * dataLen, 14 * dataLen, 15 * dataLen};

    using index_t = avx512Int16<avx512Float16::scalarIndexType>;
    index_t index0(tmp);

    for (size_t i = 0; i < dataLen; ++i)
    {
        in[i].scatter(out, index0);
        index0 = index0 + 1;
    }
}
// mask type: a "broad boolean" integer vector whose lanes hold either true_v
// (all bits set) or false_v (zero); support is limited to the comparison and
// reduction operators below
struct avx512Mask8 : avx512Long8<std::uint64_t>
{
    // bring in ctors
    using avx512Long8::avx512Long8;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx512Mask8 operator>(avx512Double8 lhs, avx512Double8 rhs)
{
    __mmask8 mask = _mm512_cmp_pd_mask(lhs._data, rhs._data, _CMP_GT_OQ);
    return _mm512_maskz_set1_epi64(mask, avx512Mask8::true_v);
}

inline bool operator&&(avx512Mask8 lhs, bool rhs)
{
    __m512i val_true = _mm512_set1_epi64(avx512Mask8::true_v);
    __mmask8 mask    = _mm512_test_epi64_mask(lhs._data, val_true);
    unsigned int tmp = _cvtmask16_u32(mask);
    return tmp && rhs;
}
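
// Illustrative sketch (not part of the library): lane-wise comparison yields
// a broad boolean vector, which operator&& combines with a scalar bool.
//
//     avx512Double8 a(1.0), b(2.0);
//     avx512Mask8 m = b > a;  // every lane of m is set
//     bool ok = m && true;    // true here, since the mask is non-zero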
struct avx512Mask16 : avx512Int16<std::uint32_t>
{
    // bring in ctors
    using avx512Int16::avx512Int16;

    static constexpr scalarType true_v  = -1;
    static constexpr scalarType false_v = 0;
};

inline avx512Mask16 operator>(avx512Float16 lhs, avx512Float16 rhs)
{
    __mmask16 mask = _mm512_cmp_ps_mask(lhs._data, rhs._data, _CMP_GT_OQ);
    return _mm512_maskz_set1_epi32(mask, avx512Mask16::true_v);
}

inline bool operator&&(avx512Mask16 lhs, bool rhs)
{
    __m512i val_true = _mm512_set1_epi32(avx512Mask16::true_v);
    __mmask16 mask   = _mm512_test_epi32_mask(lhs._data, val_true);
    unsigned int tmp = _cvtmask16_u32(mask);
    return tmp && rhs;
}

#endif // defined(__AVX512F__) && defined(NEKTAR_ENABLE_SIMD_AVX512)

#endif // NEKTAR_LIB_LIBUTILITES_SIMDLIB_AVX512_H