Skip to content

Commit

Permalink
Merge pull request #401 from rust-lang/std_float_improvements
Browse files Browse the repository at this point in the history
Test StdFloat
  • Loading branch information
calebzulawski authored Mar 6, 2024
2 parents eea6f77 + 278eb28 commit 5794c83
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 116 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions crates/std_float/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ edition = "2021"
[dependencies]
core_simd = { path = "../core_simd", default-features = false }

[dev-dependencies.test_helpers]
path = "../test_helpers"

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen = "0.2"
wasm-bindgen-test = "0.3"

[features]
default = ["as_crate"]
as_crate = []
206 changes: 90 additions & 116 deletions crates/std_float/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![cfg_attr(feature = "as_crate", no_std)] // We are std!
#![cfg_attr(
feature = "as_crate",
feature(core_intrinsics),
Expand Down Expand Up @@ -44,7 +43,7 @@ use crate::sealed::Sealed;
/// For now this trait is available to permit experimentation with SIMD float
/// operations that may lack hardware support, such as `mul_add`.
pub trait StdFloat: Sealed + Sized {
/// Fused multiply-add. Computes `(self * a) + b` with only one rounding error,
/// Elementwise fused multiply-add. Computes `(self * a) + b` with only one rounding error,
/// yielding a more accurate result than an unfused multiply-add.
///
/// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
Expand All @@ -57,78 +56,65 @@ pub trait StdFloat: Sealed + Sized {
unsafe { intrinsics::simd_fma(self, a, b) }
}

/// Produces a vector where every lane has the square root value
/// of the equivalently-indexed lane in `self`
/// Produces a vector where every element has the square root value
/// of the equivalently-indexed element in `self`
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
fn sqrt(self) -> Self {
unsafe { intrinsics::simd_fsqrt(self) }
}

/// Produces a vector where every lane has the sine of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the sine of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn sin(self) -> Self {
unsafe { intrinsics::simd_fsin(self) }
}
fn sin(self) -> Self;

/// Produces a vector where every lane has the cosine of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the cosine of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn cos(self) -> Self {
unsafe { intrinsics::simd_fcos(self) }
}
fn cos(self) -> Self;

/// Produces a vector where every lane has the exponential (base e) of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the exponential (base e) of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn exp(self) -> Self {
unsafe { intrinsics::simd_fexp(self) }
}
fn exp(self) -> Self;

/// Produces a vector where every lane has the exponential (base 2) of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the exponential (base 2) of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn exp2(self) -> Self {
unsafe { intrinsics::simd_fexp2(self) }
}
fn exp2(self) -> Self;

/// Produces a vector where every lane has the natural logarithm of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the natural logarithm of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn log(self) -> Self {
unsafe { intrinsics::simd_flog(self) }
}
fn ln(self) -> Self;

/// Produces a vector where every lane has the base-2 logarithm of the value
/// in the equivalently-indexed lane in `self`.
/// Produces a vector where every element has the logarithm with respect to an arbitrary
/// in the equivalently-indexed elements in `self` and `base`.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
fn log2(self) -> Self {
unsafe { intrinsics::simd_flog2(self) }
fn log(self, base: Self) -> Self {
unsafe { intrinsics::simd_div(self.ln(), base.ln()) }
}

/// Produces a vector where every lane has the base-10 logarithm of the value
/// in the equivalently-indexed lane in `self`.
#[inline]
/// Produces a vector where every element has the base-2 logarithm of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn log10(self) -> Self {
unsafe { intrinsics::simd_flog10(self) }
}
fn log2(self) -> Self;

/// Produces a vector where every element has the base-10 logarithm of the value
/// in the equivalently-indexed element in `self`.
#[must_use = "method returns a new vector and does not mutate the original value"]
fn log10(self) -> Self;

/// Returns the smallest integer greater than or equal to each lane.
/// Returns the smallest integer greater than or equal to each element.
#[must_use = "method returns a new vector and does not mutate the original value"]
#[inline]
fn ceil(self) -> Self {
unsafe { intrinsics::simd_ceil(self) }
}

/// Returns the largest integer value less than or equal to each lane.
/// Returns the largest integer value less than or equal to each element.
#[must_use = "method returns a new vector and does not mutate the original value"]
#[inline]
fn floor(self) -> Self {
Expand Down Expand Up @@ -157,77 +143,65 @@ pub trait StdFloat: Sealed + Sized {
impl<const N: usize> Sealed for Simd<f32, N> where LaneCount<N>: SupportedLaneCount {}
impl<const N: usize> Sealed for Simd<f64, N> where LaneCount<N>: SupportedLaneCount {}

// We can safely just use all the defaults.
impl<const N: usize> StdFloat for Simd<f32, N>
where
LaneCount<N>: SupportedLaneCount,
{
/// Returns the floating point's fractional value, with its integer part removed.
#[must_use = "method returns a new vector and does not mutate the original value"]
#[inline]
fn fract(self) -> Self {
self - self.trunc()
macro_rules! impl_float {
{
$($fn:ident: $intrinsic:ident,)*
} => {
impl<const N: usize> StdFloat for Simd<f32, N>
where
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn fract(self) -> Self {
self - self.trunc()
}

$(
#[inline]
fn $fn(self) -> Self {
unsafe { intrinsics::$intrinsic(self) }
}
)*
}

impl<const N: usize> StdFloat for Simd<f64, N>
where
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn fract(self) -> Self {
self - self.trunc()
}

$(
#[inline]
fn $fn(self) -> Self {
// https://github.com/llvm/llvm-project/issues/83729
#[cfg(target_arch = "aarch64")]
{
let mut ln = Self::splat(0f64);
for i in 0..N {
ln[i] = self[i].$fn()
}
ln
}

#[cfg(not(target_arch = "aarch64"))]
{
unsafe { intrinsics::$intrinsic(self) }
}
}
)*
}
}
}

impl<const N: usize> StdFloat for Simd<f64, N>
where
LaneCount<N>: SupportedLaneCount,
{
/// Returns the floating point's fractional value, with its integer part removed.
#[must_use = "method returns a new vector and does not mutate the original value"]
#[inline]
fn fract(self) -> Self {
self - self.trunc()
}
}

#[cfg(test)]
mod tests_simd_floats {
use super::*;
use simd::prelude::*;

#[test]
fn everything_works_f32() {
let x = f32x4::from_array([0.1, 0.5, 0.6, -1.5]);

let x2 = x + x;
let _xc = x.ceil();
let _xf = x.floor();
let _xr = x.round();
let _xt = x.trunc();
let _xfma = x.mul_add(x, x);
let _xsqrt = x.sqrt();
let _abs_mul = x2.abs() * x2;

let _fexp = x.exp();
let _fexp2 = x.exp2();
let _flog = x.log();
let _flog2 = x.log2();
let _flog10 = x.log10();
let _fsin = x.sin();
let _fcos = x.cos();
}

#[test]
fn everything_works_f64() {
let x = f64x4::from_array([0.1, 0.5, 0.6, -1.5]);

let x2 = x + x;
let _xc = x.ceil();
let _xf = x.floor();
let _xr = x.round();
let _xt = x.trunc();
let _xfma = x.mul_add(x, x);
let _xsqrt = x.sqrt();
let _abs_mul = x2.abs() * x2;

let _fexp = x.exp();
let _fexp2 = x.exp2();
let _flog = x.log();
let _flog2 = x.log2();
let _flog10 = x.log10();
let _fsin = x.sin();
let _fcos = x.cos();
}
impl_float! {
sin: simd_fsin,
cos: simd_fcos,
exp: simd_fexp,
exp2: simd_fexp2,
ln: simd_flog,
log2: simd_flog2,
log10: simd_flog10,
}
74 changes: 74 additions & 0 deletions crates/std_float/tests/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#![feature(portable_simd)]

macro_rules! unary_test {
{ $scalar:tt, $($func:tt),+ } => {
test_helpers::test_lanes! {
$(
fn $func<const LANES: usize>() {
test_helpers::test_unary_elementwise(
&core_simd::simd::Simd::<$scalar, LANES>::$func,
&$scalar::$func,
&|_| true,
)
}
)*
}
}
}

macro_rules! binary_test {
{ $scalar:tt, $($func:tt),+ } => {
test_helpers::test_lanes! {
$(
fn $func<const LANES: usize>() {
test_helpers::test_binary_elementwise(
&core_simd::simd::Simd::<$scalar, LANES>::$func,
&$scalar::$func,
&|_, _| true,
)
}
)*
}
}
}

macro_rules! ternary_test {
{ $scalar:tt, $($func:tt),+ } => {
test_helpers::test_lanes! {
$(
fn $func<const LANES: usize>() {
test_helpers::test_ternary_elementwise(
&core_simd::simd::Simd::<$scalar, LANES>::$func,
&$scalar::$func,
&|_, _, _| true,
)
}
)*
}
}
}

macro_rules! impl_tests {
{ $scalar:tt } => {
mod $scalar {
use std_float::StdFloat;

unary_test! { $scalar, sqrt, sin, cos, exp, exp2, ln, log2, log10, ceil, floor, round, trunc }
binary_test! { $scalar, log }
ternary_test! { $scalar, mul_add }

test_helpers::test_lanes! {
fn fract<const LANES: usize>() {
test_helpers::test_unary_elementwise_flush_subnormals(
&core_simd::simd::Simd::<$scalar, LANES>::fract,
&$scalar::fract,
&|_| true,
)
}
}
}
}
}

impl_tests! { f32 }
impl_tests! { f64 }

0 comments on commit 5794c83

Please sign in to comment.