37#ifndef NEKTAR_LIB_LIBUTILITIES_BASSICUTILS_VMATHSIMD_HPP
38#define NEKTAR_LIB_LIBUTILITIES_BASSICUTILS_VMATHSIMD_HPP
/// \brief Add vector z = x + y.
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param x  first input array (read-only)
/// \param y  second input array (read-only)
/// \param z  output array
///
/// NOTE(review): this listing is an excerpt; braces, the chunk load/store
/// statements and any remainder handling are not shown here.
45template <
class T,
typename =
typename std::enable_if<
46 std::is_floating_point<T>::value>::type>
47void Vadd(
const size_t n,
const T *x,
const T *y, T *
z)
// 4x-unrolled pass: runs while at least four vec_t widths of elements remain.
54 while (cnt >= 4 * vec_t::width)
57 vec_t yChunk0, yChunk1, yChunk2, yChunk3;
63 vec_t xChunk0, xChunk1, xChunk2, xChunk3;
// Four independent chunk-wise additions per iteration.
70 vec_t zChunk0 = xChunk0 + yChunk0;
71 vec_t zChunk1 = xChunk1 + yChunk1;
72 vec_t zChunk2 = xChunk2 + yChunk2;
73 vec_t zChunk3 = xChunk3 + yChunk3;
// Advance all three pointers and shrink the remaining count by four widths.
82 x += 4 * vec_t::width;
83 y += 4 * vec_t::width;
84 z += 4 * vec_t::width;
85 cnt -= 4 * vec_t::width;
// 2x-unrolled pass for what the 4x pass left over.
89 while (cnt >= 2 * vec_t::width)
92 vec_t yChunk0, yChunk1;
96 vec_t xChunk0, xChunk1;
101 vec_t zChunk0 = xChunk0 + yChunk0;
102 vec_t zChunk1 = xChunk1 + yChunk1;
109 x += 2 * vec_t::width;
110 y += 2 * vec_t::width;
111 z += 2 * vec_t::width;
112 cnt -= 2 * vec_t::width;
// Single-vector pass: one vec_t width per iteration.
116 while (cnt >= vec_t::width)
125 vec_t zChunk = xChunk + yChunk;
/// \brief Multiply vector z = x * y (element-wise).
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param x  first input array (read-only)
/// \param y  second input array (read-only)
/// \param z  output array
///
/// NOTE(review): this listing is an excerpt; braces, the chunk load/store
/// statements and any remainder handling are not shown here.
151template <
class T,
typename =
typename std::enable_if<
152 std::is_floating_point<T>::value>::type>
153void Vmul(
const size_t n,
const T *x,
const T *y, T *
z)
// 4x-unrolled pass: runs while at least four vec_t widths of elements remain.
160 while (cnt >= 4 * vec_t::width)
163 vec_t yChunk0, yChunk1, yChunk2, yChunk3;
169 vec_t xChunk0, xChunk1, xChunk2, xChunk3;
// Four independent chunk-wise products per iteration.
176 vec_t zChunk0 = xChunk0 * yChunk0;
177 vec_t zChunk1 = xChunk1 * yChunk1;
178 vec_t zChunk2 = xChunk2 * yChunk2;
179 vec_t zChunk3 = xChunk3 * yChunk3;
// Advance all three pointers and shrink the remaining count by four widths.
188 x += 4 * vec_t::width;
189 y += 4 * vec_t::width;
190 z += 4 * vec_t::width;
191 cnt -= 4 * vec_t::width;
// 2x-unrolled pass for what the 4x pass left over.
195 while (cnt >= 2 * vec_t::width)
198 vec_t yChunk0, yChunk1;
202 vec_t xChunk0, xChunk1;
207 vec_t zChunk0 = xChunk0 * yChunk0;
208 vec_t zChunk1 = xChunk1 * yChunk1;
215 x += 2 * vec_t::width;
216 y += 2 * vec_t::width;
217 z += 2 * vec_t::width;
218 cnt -= 2 * vec_t::width;
// Single-vector pass: one vec_t width per iteration.
222 while (cnt >= vec_t::width)
231 vec_t zChunk = xChunk * yChunk;
/// \brief vvtvp (vector times vector plus vector): z = w*x + y
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param w  first factor array (read-only)
/// \param x  second factor array (read-only)
/// \param y  addend array (read-only)
/// \param z  output array
///
/// NOTE(review): this listing is an excerpt; braces, chunk loads/stores and
/// the pointer updates are not shown here.
257template <
class T,
typename =
typename std::enable_if<
258 std::is_floating_point<T>::value>::type>
259void Vvtvp(
const size_t n,
const T *
w,
const T *x,
const T *y, T *
z)
// Vector pass: one vec_t width per iteration (no unrolled variants are visible).
266 while (cnt >= vec_t::width)
277 vec_t zChunk = wChunk * xChunk + yChunk;
// Scalar form of the same formula through the raw pointers; presumably the
// remainder path for elements left after the vector loop -- TODO confirm.
294 *
z = (*w) * (*x) + (*y);
/// \brief vvtvm (vector times vector minus vector): z = w*x - y
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param w  first factor array (read-only)
/// \param x  second factor array (read-only)
/// \param y  subtrahend array (read-only)
/// \param z  output array
///
/// NOTE(review): this listing is an excerpt; braces, chunk loads/stores and
/// the pointer updates are not shown here.
305template <
class T,
typename =
typename std::enable_if<
306 std::is_floating_point<T>::value>::type>
307void Vvtvm(
const size_t n,
const T *
w,
const T *x,
const T *y, T *
z)
// Vector pass: one vec_t width per iteration (no unrolled variants are visible).
314 while (cnt >= vec_t::width)
325 vec_t zChunk = wChunk * xChunk - yChunk;
// Scalar form of the same formula through the raw pointers; presumably the
// remainder path for elements left after the vector loop -- TODO confirm.
342 *
z = (*w) * (*x) - (*y);
/// \brief vvtvvtp (vector times vector plus vector times vector):
///        z = v*w + x*y
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param v  first factor of the first product (read-only)
/// \param w  second factor of the first product (read-only)
/// \param x  first factor of the second product (read-only)
///
/// NOTE(review): the parameter list is cut off after x in this excerpt; the
/// body uses yChunk and produces zChunk, so y (input) and z (output)
/// parameters presumably follow -- TODO confirm against the full header.
354template <
class T,
typename =
typename std::enable_if<
355 std::is_floating_point<T>::value>::type>
356inline void Vvtvvtp(
const size_t n,
const T *v,
const T *
w,
const T *x,
// Vector pass: one vec_t width per iteration.
364 while (cnt >= vec_t::width)
// Two chunk-wise products, then their sum.
377 vec_t z1Chunk = vChunk * wChunk;
378 vec_t z2Chunk = xChunk * yChunk;
379 vec_t zChunk = z1Chunk + z2Chunk;
/// \brief vvtvvtm (vector times vector minus vector times vector):
///        z = v*w - x*y
///
/// Only participates in overload resolution when T is a floating-point type
/// (std::enable_if on std::is_floating_point<T>).
///
/// \param n  element count (presumably the initial value of cnt; the
///           initialisation is not visible in this excerpt -- TODO confirm)
/// \param v  first factor of the first product (read-only)
/// \param w  second factor of the first product (read-only)
/// \param x  first factor of the second product (read-only)
///
/// NOTE(review): the parameter list is cut off after x in this excerpt; the
/// body uses yChunk and produces zChunk, so y (input) and z (output)
/// parameters presumably follow -- TODO confirm against the full header.
412template <
class T,
typename =
typename std::enable_if<
413 std::is_floating_point<T>::value>::type>
414inline void Vvtvvtm(
const size_t n,
const T *v,
const T *
w,
const T *x,
// Vector pass: one vec_t width per iteration.
422 while (cnt >= vec_t::width)
// Two chunk-wise products, then their difference (first minus second).
435 vec_t z1Chunk = vChunk * wChunk;
436 vec_t z2Chunk = xChunk * yChunk;
437 vec_t zChunk = z1Chunk - z2Chunk;
/// \brief Gather vector z[i] = x[y[i]].
///
/// Only participates in overload resolution when T is a floating-point type
/// and I is an integral type (std::enable_if on std::is_floating_point<T>
/// and std::is_integral<I>).
///
/// \param n  element count, of the index type I (presumably the initial
///           value of cnt; the initialisation is not visible in this
///           excerpt -- TODO confirm)
/// \param x  source array that the indices refer to (read-only)
/// \param y  index array (read-only)
/// \param z  output array
///
/// NOTE(review): this listing is an excerpt; braces, the index loads and the
/// result stores are not shown here.
/// NOTE(review): y is advanced by vec_t_i::width while z and cnt use
/// vec_t::width, which appears to rely on both widths being equal -- confirm.
469template <
class T,
class I,
470 typename =
typename std::enable_if<std::is_floating_point<T>::value &&
471 std::is_integral<I>::value>::type>
472void Gathr(
const I n,
const T *x,
const I *y, T *
z)
// 4x-unrolled pass: runs while at least four vec_t widths of elements remain.
480 while (cnt >= 4 * vec_t::width)
483 vec_t_i yChunk0, yChunk1, yChunk2, yChunk3;
490 vec_t zChunk0, zChunk1, zChunk2, zChunk3;
// Each result chunk is gathered from x using the matching index chunk.
491 zChunk0.gather(x, yChunk0);
492 zChunk1.gather(x, yChunk1);
493 zChunk2.gather(x, yChunk2);
494 zChunk3.gather(x, yChunk3);
// Only y and z advance; no update of x appears among these pointer updates.
503 y += 4 * vec_t_i::width;
504 z += 4 * vec_t::width;
505 cnt -= 4 * vec_t::width;
// 2x-unrolled pass for what the 4x pass left over.
509 while (cnt >= 2 * vec_t::width)
512 vec_t_i yChunk0, yChunk1;
517 vec_t zChunk0, zChunk1;
518 zChunk0.gather(x, yChunk0);
519 zChunk1.gather(x, yChunk1);
526 y += 2 * vec_t_i::width;
527 z += 2 * vec_t::width;
528 cnt -= 2 * vec_t::width;
// Single-vector pass: one vec_t width per iteration.
532 while (cnt >= vec_t::width)
540 zChunk.gather(x, yChunk);
std::vector< double > w(NPUPPER)
std::vector< double > z(NPUPPER)
tinysimd::simd< NekDouble > vec_t
void Vvtvp(const size_t n, const T *w, const T *x, const T *y, T *z)
vvtvp (vector times vector plus vector): z = w*x + y
void Vadd(const size_t n, const T *x, const T *y, T *z)
Add vector z = x + y.
void Vvtvm(const size_t n, const T *w, const T *x, const T *y, T *z)
vvtvm (vector times vector minus vector): z = w*x - y
void Vvtvvtm(const size_t n, const T *v, const T *w, const T *x, const T *y, T *z)
vvtvvtm (vector times vector minus vector times vector): z = v*w - x*y
void Gathr(const I n, const T *x, const I *y, T *z)
Gather vector z[i] = x[y[i]].
void Vvtvvtp(const size_t n, const T *v, const T *w, const T *x, const T *y, T *z)
vvtvvtp (vector times vector plus vector times vector): z = v*w + x*y
void Vmul(const size_t n, const T *x, const T *y, T *z)
Multiply vector z = x * y.
static constexpr struct tinysimd::is_not_aligned_t is_not_aligned
typename abi< ScalarType, width >::type simd