Skip to content

Commit

Permalink
feat: pnt (#35)
Browse files Browse the repository at this point in the history
Adds the `pnt` function
  • Loading branch information
storopoli authored Apr 29, 2024
1 parent b9d4d3c commit 36d7834
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The following functions have been ported:
Distribution | Density | Probability | Quantile | Random Generation
--- | :---: | :---: | :---: | :---:
Normal | `dnorm` | `pnorm` | `qnorm` |
Student's t | `dt` | `pt` | |
Student's t | `dt` | `pt`, `pnt` | |
Beta | | `pbeta` | |
Poisson | `dpois` | | |
Gamma | `dgamma` | `pgamma` |
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod nmath;
mod pbeta;
mod pgamma;
mod pnorm;
mod pnt;
mod pt;
mod qnorm;
mod rmath;
Expand Down Expand Up @@ -51,6 +52,7 @@ pub use pbeta::pbeta;
pub use pgamma::log1pmx;
pub use pgamma::logspace_add;
pub use pgamma::pgamma;
pub use pnt::pnt;
pub use pt::pt;
pub use rmath::dnorm;
pub use rmath::pnorm;
Expand Down
9 changes: 9 additions & 0 deletions src/nmath.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@ pub fn r_finite(x: f64) -> bool {

pub const ML_NAN: f64 = f64::NAN;

/// log(sqrt(pi))
pub const M_LN_SQRT_PI: f64 = 0.572_364_942_924_700_1;
/// log(sqrt(2*pi)) == log(2*pi)/2
pub const M_LN_SQRT_2PI: f64 = 0.918_938_533_204_672_8;

/// for IEEE, DBL_MIN_EXP is -1022 but "effective" is -1074
pub const DBL_MIN_EXP: f64 = f64::MIN_EXP as f64;

/// sqrt(2/pi)
#[allow(non_upper_case_globals)]
pub const M_SQRT_2dPI: f64 = 0.797_884_560_802_865_4;

/// log(1 - exp(x)) in more stable form than log1p(- r_d_qiv(x)) :
pub fn r_log1_exp(x: f64) -> f64 {
if x > -M_LN2 {
Expand Down
213 changes: 213 additions & 0 deletions src/pnt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#![allow(unused_assignments)]
use crate::dpq::r_dt_0;
use crate::dpq::r_dt_1;
use crate::lgammafn;
use crate::libc::fabs;
use crate::libc::DBL_EPSILON;
use crate::nmath::ml_warn_return_nan;
use crate::nmath::r_finite;
use crate::nmath::M_SQRT_2dPI;
use crate::nmath::DBL_MIN_EXP;
use crate::nmath::M_LN_SQRT_PI;
use crate::pbeta;
use crate::pnorm;
use crate::pt;
use crate::rmath::M_LN2;
use libm::exp;
use libm::expm1;
use libm::fmin;
use libm::log1p;
use libm::pow;
use libm::sqrt;

// NOTE: itrmax and errmax may be changed to suit one's needs.
const ITRMAX: usize = 1_000;
const ERRMAX: f64 = 1e-12;

/// Non-central t distribution
///
/// Algorithm AS 243 Lenth,R.V. (1989). Appl. Statist., Vol.38, 185-189.
///
/// Cumulative probability at t of the non-central t-distribution
/// with df degrees of freedom (may be fractional) and non-centrality
/// parameter delta.
pub fn pnt(t: f64, df: f64, ncp: f64, mut lower_tail: bool, log_p: bool) -> f64 {
// initialize variables
let mut albeta = 0.0;
let mut a = 0.0;
let mut b = 0.0;
let mut del = 0.0;
let mut errbd = 0.0;
let mut lambda = 0.0;
let mut rxb = 0.0;
let mut tt = 0.0;
let mut x = 0.0;
let mut geven = 0.0;
let mut godd = 0.0;
let mut p = 0.0;
let mut q = 0.0;
let mut s = 0.0;
let mut tnc = 0.0;
let mut xeven = 0.0;
let mut xodd = 0.0;
let mut negdel = false;

if df <= 0.0 {
return ml_warn_return_nan();
}

if ncp == 0.0 {
return pt(t, df, lower_tail, log_p);
}

if !r_finite(t) {
if t < 0.0 {
return r_dt_0(lower_tail, log_p);
} else {
return r_dt_1(lower_tail, log_p);
}
}
if t >= 0.0 {
negdel = false;
tt = t;
del = ncp;
} else {
/* We deal quickly with left tail if extreme,
since pt(q, df, ncp) <= pt(0, df, ncp) = \Phi(-ncp) */
if ncp > 40.0 && (!log_p || !lower_tail) {
return r_dt_0(lower_tail, log_p);
}
negdel = true;
tt = -t;
del = -ncp;
}

if df > 4e5 || del * del > 2.0 * M_LN2 * (-DBL_MIN_EXP) {
/*-- 2nd part: if del > 37.62, then p=0 below
FIXME: test should depend on `df', `tt' AND `del' ! */
/* Approx. from Abramowitz & Stegun 26.7.10 (p.949) */

let s = 1.0 / (4.0 * df);

return pnorm(
tt * (1.0 - s),
del,
sqrt(1.0 + tt * tt * 2.0 * s),
lower_tail != negdel,
log_p,
);
}

/* initialize twin series */
/* Guenther, J. (1978). Statist. Computn. Simuln. vol.6, 199. */

x = t * t;
rxb = df / (x + df); /* := (1 - x) {x below} -- but more accurately */
x = x / (x + df); /* in [0,1) */

if x > 0.0 {
// <==> t != 0
lambda = del * del;
p = 0.5 * exp(-0.5 * lambda);

if p == 0.0 {
// underflow
println!("pnt: underflow occurred");
println!("pnt: value out of range");
return r_dt_0(lower_tail, log_p);
}

q = M_SQRT_2dPI * p * del;
s = 0.5 - p;
/* s = 0.5 - p = 0.5*(1 - exp(-.5 L)) = -0.5*expm1(-.5 L)) */
if s < 1e-7 {
s = -0.5 * expm1(-0.5 * lambda);
}
a = 0.5;
b = 0.5 * df;
/* rxb = (1 - x) ^ b [ ~= 1 - b*x for tiny x --> see 'xeven' below]
* where '(1 - x)' =: rxb {accurately!} above */
rxb = pow(rxb, b);
albeta = M_LN_SQRT_PI + lgammafn(b) - lgammafn(0.5 + b);
xodd = pbeta(x, a, b, true, false);
godd = 2.0 * rxb * exp(a * x.ln() - albeta);
tnc = b * x;
xeven = if tnc < DBL_EPSILON { tnc } else { 1.0 - rxb };
geven = tnc * rxb;
tnc = p * xodd + q * xeven;

/* repeat until convergence or iteration limit */
for it in 1..=ITRMAX {
a += 1.0;
xodd -= godd;
xeven -= geven;
godd *= x * (a + b - 1.0) / a;
geven *= x * (a + b - 0.5) / (a + 0.5);
p *= lambda / (2.0 * it as f64);
q *= lambda / (2.0 * it as f64 + 1.0);
tnc += p * xodd + q * xeven;
s -= p;

if s < 1e-10 {
/* happens e.g. for (t,df,ncp)=(40,10,38.5), after 799 it.*/
println!("pnt: full precision may not have been achieved");
finis(del, &mut tnc);
break;
}

if s <= 0.0 && it > 1 {
finis(del, &mut tnc);
break;
}

errbd = 2.0 * s * (xodd - godd);

if fabs(errbd) < ERRMAX {
// convergence
finis(del, &mut tnc);
break;
}
}
// non-convergence:
println!("pnt: convergence failed");
} else {
/* x = t = 0 */
tnc = 0.0;
}

lower_tail = lower_tail != negdel; /* xor */
if tnc > 1.0 - 1e-10 && lower_tail {
println!("pnt: full precision may not have been achieved");
}

r_dt_val(fmin(tnc, 1.0), lower_tail, log_p)
}

// converting the goto statement
fn finis(del: f64, tnc: &mut f64) {
*tnc += pnorm(-del, 0.0, 1.0, true, false);
}

pub fn r_dt_val(x: f64, lower_tail: bool, log_p: bool) -> f64 {
if lower_tail {
r_d_val(x, log_p)
} else {
r_d_clog(x, log_p)
}
}

pub fn r_d_val(x: f64, log_p: bool) -> f64 {
if log_p {
x.ln()
} else {
x
}
}

pub fn r_d_clog(p: f64, log_p: bool) -> f64 {
if log_p {
log1p(-p)
} else {
0.5 - p + 0.5
}
}
60 changes: 60 additions & 0 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mod test_math {
pub fn lgammafn_sign(x: f64, sgn: Option<&mut i32>) -> f64;
pub fn pgamma(x: f64, alph: f64, scale: f64, lower_tail: i32, log_p: i32) -> f64;
pub fn pnorm5(x: f64, mu: f64, sigma: f64, lower_tail: i32, log_p: i32) -> f64;
pub fn pnt(t: f64, df: f64, ncp: f64, lower_tail: bool, log_p: bool) -> f64;
pub fn pt(x: f64, n: f64, lower_tail: bool, log_p: bool) -> f64;
pub fn qnorm5(p: f64, mu: f64, sigma: f64, lower_tail: i32, log_p: i32) -> f64;
pub fn sinpi(x: f64) -> f64;
Expand Down Expand Up @@ -317,6 +318,65 @@ mod test_math {
});
}

#[test]
fn test_pnt() {
assert!(abs_diff_eq!(
pnt(0.1, 1.0, 1.0, false, false),
unsafe { c::pnt(0.1, 1.0, 1.0, false, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(-0.1, 1.0, 1.0, false, false),
unsafe { c::pnt(-0.1, 1.0, 1.0, false, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 1.0, 1.0, false, false),
unsafe { c::pnt(1.0, 1.0, 1.0, false, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 1.0, 1.0, false, true),
unsafe { c::pnt(1.0, 1.0, 1.0, false, true) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 1.0, 1.0, true, true),
unsafe { c::pnt(1.0, 1.0, 1.0, true, true) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 10.0, -10.0, false, false),
unsafe { c::pnt(1.0, 10.0, -10.0, false, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(-1.0, 10.0, -10.0, false, false),
unsafe { c::pnt(-1.0, 10.0, -10.0, false, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 10.0, 10.0, true, false),
unsafe { c::pnt(1.0, 10.0, 10.0, true, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(-1.0, 10.0, 10.0, true, false),
unsafe { c::pnt(1.0, 10.0, 10.0, true, false) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(1.0, 10.0, 10.0, false, true),
unsafe { c::pnt(1.0, 10.0, 10.0, false, true) },
epsilon = 1e-15
));
assert!(abs_diff_eq!(
pnt(-1.0, 10.0, 10.0, false, true),
unsafe { c::pnt(-1.0, 10.0, 10.0, false, true) },
epsilon = 1e-15
));
}

#[test]
fn test_pt() {
assert_eq!(pt(0.1, 1.0, false, false), unsafe {
Expand Down

0 comments on commit 36d7834

Please sign in to comment.