Skip to content

Commit

Permalink
Merge pull request xtensor-stack#962 from anutosh491/remaining_ops_impl
Browse files Browse the repository at this point in the history
Implemented few operations for the wasm instruction set
  • Loading branch information
JohanMabille authored Nov 2, 2023
2 parents 46c561b + b816668 commit 105658a
Showing 1 changed file with 180 additions and 6 deletions.
186 changes: 180 additions & 6 deletions include/xsimd/arch/xsimd_wasm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@

namespace xsimd
{
template <class batch_type, bool... Values>
struct batch_bool_constant;

template <class T_out, class T_in, class A>
inline batch<T_out, A> bitwise_cast(batch<T_in, A> const& x) noexcept;

template <class batch_type, typename batch_type::value_type... Values>
struct batch_constant;

namespace kernel
{
using namespace types;

// fwd
template <class A, class T, size_t I>
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I>, requires_arch<generic>) noexcept;
template <class A, typename T, typename ITy, ITy... Indices>
inline batch<T, A> shuffle(batch<T, A> const& x, batch<T, A> const& y, batch_constant<batch<ITy, A>, Indices...>, requires_arch<generic>) noexcept;

// abs
template <class A, class T, typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, void>::type>
inline batch<T, A> abs(batch<T, A> const& self, requires_arch<wasm>) noexcept
Expand Down Expand Up @@ -136,6 +150,13 @@ namespace xsimd
return wasm_i8x16_bitmask(self) != 0;
}

// batch_bool_cast
template <class A, class T_out, class T_in>
inline batch_bool<T_out, A> batch_bool_cast(batch_bool<T_in, A> const& self, batch_bool<T_out, A> const&, requires_arch<wasm>) noexcept
{
return { bitwise_cast<T_out>(batch<T_in, A>(self.data)).data };
}

// bitwise_and
template <class A, class T>
inline batch<T, A> bitwise_and(batch<T, A> const& self, batch<T, A> const& other, requires_arch<wasm>) noexcept
Expand All @@ -162,6 +183,13 @@ namespace xsimd
return wasm_v128_andnot(self, other);
}

// bitwise_cast
template <class A, class T, class Tp>
inline batch<Tp, A> bitwise_cast(batch<T, A> const& self, batch<Tp, A> const&, requires_arch<wasm>) noexcept
{
return batch<Tp, A>(self.data);
}

// bitwise_or
template <class A, class T>
inline batch<T, A> bitwise_or(batch<T, A> const& self, batch<T, A> const& other, requires_arch<wasm>) noexcept
Expand Down Expand Up @@ -415,6 +443,53 @@ namespace xsimd
return wasm_f64x2_eq(self, other);
}

// fast_cast
namespace detail
{
template <class A>
inline batch<float, A> fast_cast(batch<int32_t, A> const& self, batch<float, A> const&, requires_arch<wasm>) noexcept
{
return wasm_f32x4_convert_i32x4(self);
}

template <class A>
inline batch<double, A> fast_cast(batch<uint64_t, A> const& x, batch<double, A> const&, requires_arch<wasm>) noexcept
{
// from https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx
// adapted to wasm
v128_t xH = wasm_u64x2_shr(x, 32);
xH = wasm_v128_or(xH, wasm_f64x2_splat(19342813113834066795298816.)); // 2^84
v128_t mask = wasm_i16x8_make(0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000);
v128_t xL = wasm_v128_or(wasm_v128_and(mask, x), wasm_v128_andnot(wasm_f64x2_splat(0x0010000000000000), mask)); // 2^52
v128_t f = wasm_f64x2_sub(xH, wasm_f64x2_splat(19342813118337666422669312.)); // 2^84 + 2^52
return wasm_f64x2_add(f, xL);
}

template <class A>
inline batch<double, A> fast_cast(batch<int64_t, A> const& x, batch<double, A> const&, requires_arch<wasm>) noexcept
{
// from https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx
// adapted to wasm
v128_t xH = wasm_i32x4_shr(x, 16);
xH = wasm_v128_and(xH, wasm_i16x8_make(0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF));
xH = wasm_i64x2_add(xH, wasm_f64x2_splat(442721857769029238784.)); // 3*2^67
v128_t mask = wasm_i16x8_make(0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000);
v128_t xL = wasm_v128_or(wasm_v128_and(mask, x), wasm_v128_andnot(wasm_f64x2_splat(0x0010000000000000), mask)); // 2^52
v128_t f = wasm_f64x2_sub(xH, wasm_f64x2_splat(442726361368656609280.)); // 3*2^67 + 2^52
return wasm_f64x2_add(f, xL);
}

template <class A>
inline batch<int32_t, A> fast_cast(batch<float, A> const& self, batch<int32_t, A> const&, requires_arch<wasm>) noexcept
{
return wasm_i32x4_make(
static_cast<int32_t>(wasm_f32x4_extract_lane(self, 0)),
static_cast<int32_t>(wasm_f32x4_extract_lane(self, 1)),
static_cast<int32_t>(wasm_f32x4_extract_lane(self, 2)),
static_cast<int32_t>(wasm_f32x4_extract_lane(self, 3)));
}
}

// floor
template <class A>
inline batch<float, A> floor(batch<float, A> const& self, requires_arch<wasm>) noexcept
Expand Down Expand Up @@ -516,11 +591,11 @@ namespace xsimd
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 4)
{
return from_mask(batch_bool<float, A> {}, mask, wasm {});
return batch_bool_cast<T>(from_mask(batch_bool<float, A> {}, mask, wasm {}));
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 8)
{
return from_mask(batch_bool<double, A> {}, mask, wasm {});
return batch_bool_cast<T>(from_mask(batch_bool<double, A> {}, mask, wasm {}));
}
}

Expand Down Expand Up @@ -1039,6 +1114,44 @@ namespace xsimd
return wasm_f64x2_extract_lane(tmp2, 0);
}

// reduce_max
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
inline T reduce_max(batch<T, A> const& self, requires_arch<wasm>) noexcept
{
batch<T, A> step0 = wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), 2, 3, 0, 0);
batch<T, A> acc0 = max(self, step0);

batch<T, A> step1 = wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), 1, 0, 0, 0);
batch<T, A> acc1 = max(acc0, step1);

batch<T, A> step2 = wasm_i16x8_shuffle(acc1, wasm_i16x8_splat(0), 1, 0, 0, 0, 4, 5, 6, 7);
batch<T, A> acc2 = max(acc1, step2);
if (sizeof(T) == 2)
return acc2.get(0);
batch<T, A> step3 = bitwise_cast<T>(bitwise_cast<uint16_t>(acc2) >> 8);
batch<T, A> acc3 = max(acc2, step3);
return acc3.get(0);
}

// reduce_min
template <class A, class T, class _ = typename std::enable_if<(sizeof(T) <= 2), void>::type>
inline T reduce_min(batch<T, A> const& self, requires_arch<wasm>) noexcept
{
batch<T, A> step0 = wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), 2, 3, 0, 0);
batch<T, A> acc0 = min(self, step0);

batch<T, A> step1 = wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), 1, 0, 0, 0);
batch<T, A> acc1 = min(acc0, step1);

batch<T, A> step2 = wasm_i16x8_shuffle(acc1, wasm_i16x8_splat(0), 1, 0, 0, 0, 4, 5, 6, 7);
batch<T, A> acc2 = min(acc1, step2);
if (sizeof(T) == 2)
return acc2.get(0);
batch<T, A> step3 = bitwise_cast<T>(bitwise_cast<uint16_t>(acc2) >> 8);
batch<T, A> acc3 = min(acc2, step3);
return acc3.get(0);
}

// rsqrt
template <class A>
inline batch<float, A> rsqrt(batch<float, A> const& self, requires_arch<wasm>) noexcept
Expand Down Expand Up @@ -1144,6 +1257,33 @@ namespace xsimd
return wasm_v128_or(wasm_v128_and(cond, true_br), wasm_v128_andnot(false_br, cond));
}

// shuffle
template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3>
inline batch<float, A> shuffle(batch<float, A> const& x, batch<float, A> const& y, batch_constant<batch<ITy, A>, I0, I1, I2, I3> mask, requires_arch<wasm>) noexcept
{
// shuffle within lane
if (I0 < 4 && I1 < 4 && I2 >= 4 && I3 >= 4)
return wasm_i32x4_shuffle(x, y, I0, I1, I2, I3);

// shuffle within opposite lane
if (I0 >= 4 && I1 >= 4 && I2 < 4 && I3 < 4)
return wasm_i32x4_shuffle(y, x, I0, I1, I2, I3);
return shuffle(x, y, mask, generic {});
}

template <class A, class ITy, ITy I0, ITy I1>
inline batch<double, A> shuffle(batch<double, A> const& x, batch<double, A> const& y, batch_constant<batch<ITy, A>, I0, I1> mask, requires_arch<wasm>) noexcept
{
// shuffle within lane
if (I0 < 2 && I1 >= 2)
return wasm_i64x2_shuffle(x, y, I0, I1);

// shuffle within opposite lane
if (I0 >= 2 && I1 < 2)
return wasm_i64x2_shuffle(y, x, I0, I1);
return shuffle(x, y, mask, generic {});
}

// set
template <class A, class... Values>
inline batch<float, A> set(batch<float, A> const&, requires_arch<wasm>, Values... values) noexcept
Expand Down Expand Up @@ -1243,25 +1383,21 @@ namespace xsimd
template <class A>
inline void store_aligned(float* mem, batch<float, A> const& self, requires_arch<wasm>) noexcept
{
// Assuming that mem is aligned properly, you can use wasm_v128_store to store the batch.
return wasm_v128_store(mem, self);
}
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
inline void store_aligned(T* mem, batch<T, A> const& self, requires_arch<wasm>) noexcept
{
// Assuming that mem is aligned properly, you can use wasm_v128_store to store the batch.
return wasm_v128_store((v128_t*)mem, self);
}
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
inline void store_aligned(T* mem, batch_bool<T, A> const& self, requires_arch<wasm>) noexcept
{
// Assuming that mem is aligned properly, you can use wasm_v128_store to store the batch.
return wasm_v128_store((v128_t*)mem, self);
}
template <class A>
inline void store_aligned(double* mem, batch<double, A> const& self, requires_arch<wasm>) noexcept
{
// Assuming that mem is aligned properly, you can use wasm_v128_store to store the batch.
return wasm_v128_store(mem, self);
}

Expand Down Expand Up @@ -1363,6 +1499,44 @@ namespace xsimd
return wasm_f64x2_sqrt(val);
}

// swizzle

template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<batch<uint32_t, A>, V0, V1, V2, V3>, requires_arch<wasm>) noexcept
{
return wasm_i32x4_shuffle(self, self, V0, V1, V2, V3);
}

template <class A, uint64_t V0, uint64_t V1>
inline batch<double, A> swizzle(batch<double, A> const& self, batch_constant<batch<uint64_t, A>, V0, V1>, requires_arch<wasm>) noexcept
{
return wasm_i64x2_shuffle(self, self, V0, V1);
}

template <class A, uint64_t V0, uint64_t V1>
inline batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch_constant<batch<uint64_t, A>, V0, V1>, requires_arch<wasm>) noexcept
{
return wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), 2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1);
}

template <class A, uint64_t V0, uint64_t V1>
inline batch<int64_t, A> swizzle(batch<int64_t, A> const& self, batch_constant<batch<uint64_t, A>, V0, V1> mask, requires_arch<wasm>) noexcept
{
return bitwise_cast<int64_t>(swizzle(bitwise_cast<uint64_t>(self), mask, wasm {}));
}

template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
inline batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<batch<uint32_t, A>, V0, V1, V2, V3>, requires_arch<wasm>) noexcept
{
return wasm_i32x4_shuffle(self, wasm_i32x4_splat(0), V0, V1, V2, V3);
}

template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
inline batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<batch<uint32_t, A>, V0, V1, V2, V3> mask, requires_arch<wasm>) noexcept
{
return bitwise_cast<int32_t>(swizzle(bitwise_cast<uint32_t>(self), mask, wasm {}));
}

// trunc
template <class A>
inline batch<float, A> trunc(batch<float, A> const& self, requires_arch<wasm>) noexcept
Expand Down

0 comments on commit 105658a

Please sign in to comment.