37#ifndef NEKTAR_LIB_LIBUTILITIES_BASSICUTILS_VMATHSIMD_HPP
38#define NEKTAR_LIB_LIBUTILITIES_BASSICUTILS_VMATHSIMD_HPP
47template <
class T,
typename =
typename std::enable_if<
48 std::is_floating_point<T>::value>::type>
49void Vadd(
const size_t n,
const T *x,
const T *y, T *
z)
56 while (cnt >= 4 * vec_t::width)
59 vec_t yChunk0, yChunk1, yChunk2, yChunk3;
65 vec_t xChunk0, xChunk1, xChunk2, xChunk3;
72 vec_t zChunk0 = xChunk0 + yChunk0;
73 vec_t zChunk1 = xChunk1 + yChunk1;
74 vec_t zChunk2 = xChunk2 + yChunk2;
75 vec_t zChunk3 = xChunk3 + yChunk3;
84 x += 4 * vec_t::width;
85 y += 4 * vec_t::width;
86 z += 4 * vec_t::width;
87 cnt -= 4 * vec_t::width;
91 while (cnt >= 2 * vec_t::width)
94 vec_t yChunk0, yChunk1;
98 vec_t xChunk0, xChunk1;
103 vec_t zChunk0 = xChunk0 + yChunk0;
104 vec_t zChunk1 = xChunk1 + yChunk1;
111 x += 2 * vec_t::width;
112 y += 2 * vec_t::width;
113 z += 2 * vec_t::width;
114 cnt -= 2 * vec_t::width;
118 while (cnt >= vec_t::width)
127 vec_t zChunk = xChunk + yChunk;
153template <
class T,
typename =
typename std::enable_if<
154 std::is_floating_point<T>::value>::type>
155void Vmul(
const size_t n,
const T *x,
const T *y, T *
z)
162 while (cnt >= 4 * vec_t::width)
165 vec_t yChunk0, yChunk1, yChunk2, yChunk3;
171 vec_t xChunk0, xChunk1, xChunk2, xChunk3;
178 vec_t zChunk0 = xChunk0 * yChunk0;
179 vec_t zChunk1 = xChunk1 * yChunk1;
180 vec_t zChunk2 = xChunk2 * yChunk2;
181 vec_t zChunk3 = xChunk3 * yChunk3;
190 x += 4 * vec_t::width;
191 y += 4 * vec_t::width;
192 z += 4 * vec_t::width;
193 cnt -= 4 * vec_t::width;
197 while (cnt >= 2 * vec_t::width)
200 vec_t yChunk0, yChunk1;
204 vec_t xChunk0, xChunk1;
209 vec_t zChunk0 = xChunk0 * yChunk0;
210 vec_t zChunk1 = xChunk1 * yChunk1;
217 x += 2 * vec_t::width;
218 y += 2 * vec_t::width;
219 z += 2 * vec_t::width;
220 cnt -= 2 * vec_t::width;
224 while (cnt >= vec_t::width)
233 vec_t zChunk = xChunk * yChunk;
259template <
class T,
typename =
typename std::enable_if<
260 std::is_floating_point<T>::value>::type>
261void Vvtvp(
const size_t n,
const T *
w,
const T *x,
const T *y, T *
z)
268 while (cnt >= vec_t::width)
279 vec_t zChunk = wChunk * xChunk + yChunk;
296 *
z = (*w) * (*x) + (*y);
307template <
class T,
typename =
typename std::enable_if<
308 std::is_floating_point<T>::value>::type>
309void Vvtvm(
const size_t n,
const T *
w,
const T *x,
const T *y, T *
z)
316 while (cnt >= vec_t::width)
327 vec_t zChunk = wChunk * xChunk - yChunk;
344 *
z = (*w) * (*x) - (*y);
356template <
class T,
typename =
typename std::enable_if<
357 std::is_floating_point<T>::value>::type>
358inline void Vvtvvtp(
const size_t n,
const T *v,
const T *
w,
const T *x,
366 while (cnt >= vec_t::width)
379 vec_t z1Chunk = vChunk * wChunk;
380 vec_t z2Chunk = xChunk * yChunk;
381 vec_t zChunk = z1Chunk + z2Chunk;
414template <
class T,
typename =
typename std::enable_if<
415 std::is_floating_point<T>::value>::type>
416inline void Vvtvvtm(
const size_t n,
const T *v,
const T *
w,
const T *x,
424 while (cnt >= vec_t::width)
437 vec_t z1Chunk = vChunk * wChunk;
438 vec_t z2Chunk = xChunk * yChunk;
439 vec_t zChunk = z1Chunk - z2Chunk;
471template <
class T,
class I,
472 typename =
typename std::enable_if<std::is_floating_point<T>::value &&
473 std::is_integral<I>::value>::type>
474void Gathr(
const I n,
const T *x,
const I *y, T *
z)
482 while (cnt >= 4 * vec_t::width)
485 vec_t_i yChunk0, yChunk1, yChunk2, yChunk3;
492 vec_t zChunk0, zChunk1, zChunk2, zChunk3;
493 zChunk0.gather(x, yChunk0);
494 zChunk1.gather(x, yChunk1);
495 zChunk2.gather(x, yChunk2);
496 zChunk3.gather(x, yChunk3);
505 y += 4 * vec_t_i::width;
506 z += 4 * vec_t::width;
507 cnt -= 4 * vec_t::width;
511 while (cnt >= 2 * vec_t::width)
514 vec_t_i yChunk0, yChunk1;
519 vec_t zChunk0, zChunk1;
520 zChunk0.gather(x, yChunk0);
521 zChunk1.gather(x, yChunk1);
528 y += 2 * vec_t_i::width;
529 z += 2 * vec_t::width;
530 cnt -= 2 * vec_t::width;
534 while (cnt >= vec_t::width)
542 zChunk.gather(x, yChunk);
std::vector< double > w(NPUPPER)
std::vector< double > z(NPUPPER)
tinysimd::simd< NekDouble > vec_t
void Vvtvp(const size_t n, const T *w, const T *x, const T *y, T *z)
vvtvp (vector times vector plus vector): z = w*x + y
void Vadd(const size_t n, const T *x, const T *y, T *z)
Add vector z = x + y.
void Vvtvm(const size_t n, const T *w, const T *x, const T *y, T *z)
vvtvm (vector times vector minus vector): z = w*x - y
void Vvtvvtm(const size_t n, const T *v, const T *w, const T *x, const T *y, T *z)
vvtvvtm (vector times vector minus vector times vector): z = v*w - x*y
void Gathr(const I n, const T *x, const I *y, T *z)
Gather vector z[i] = x[y[i]].
void Vvtvvtp(const size_t n, const T *v, const T *w, const T *x, const T *y, T *z)
vvtvvtp (vector times vector plus vector times vector): z = v*w + x*y
void Vmul(const size_t n, const T *x, const T *y, T *z)
Multiply vector z = x * y.
static constexpr struct tinysimd::is_not_aligned_t is_not_aligned
typename abi< ScalarType, width >::type simd