traceax is a Python library to perform stochastic trace estimation for linear operators. Namely, given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate

$$\operatorname{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii}$$

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
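The snippet below makes "using only matrix-vector products" concrete. It is a minimal from-scratch sketch of the classic Hutchinson estimator, $\operatorname{trace}(\mathbf{A}) \approx \frac{1}{k}\sum_{j=1}^{k} \omega_j^\top \mathbf{A}\,\omega_j$ with Rademacher probes $\omega_j$. It is not traceax's implementation: the helper name `hutchinson_trace` and its `matvec` argument are hypothetical, and the operator is passed as a plain function so that $\mathbf{A}$ is never formed.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, k):
    """Hypothetical helper: estimate trace(A) from k products A @ w.

    `matvec(v)` must return A @ v for a length-n vector v, so A itself
    never has to be stored as a dense matrix.
    """
    # k Rademacher probe vectors with i.i.d. entries in {-1, +1}
    probes = rdm.rademacher(key, (k, n)).astype(jnp.float32)
    # one matrix-vector product per probe, batched with vmap
    products = jax.vmap(matvec)(probes)
    # each quadratic form w^T A w has expectation trace(A); average them
    return jnp.mean(jnp.sum(probes * products, axis=1))


# usage: a diagonal operator given only as a function, trace = 1 + 2 + 3 + 4 + 5
d = jnp.arange(1.0, 6.0)
est = hutchinson_trace(rdm.PRNGKey(0), lambda v: d * v, n=5, k=10)
# for a diagonal matrix every Rademacher probe gives w^T D w = sum(d), so est is 15.0
```

The exact answer on a diagonal matrix is a property of Rademacher probes. For general matrices the standard error of this estimator shrinks like $1/\sqrt{k}$, which motivates variance-reduced estimators such as Hutch++ and XTrace.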
Users can download the latest repository and then install it with pip:

```bash
git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .
```
```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx
import traceax as tx

# simulate a simple symmetric matrix with exponential eigenvalue decay
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)
X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T  # A = Q diag(U) Q^T, so trace(A) = sum(U)

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# set up the linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.6007538, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.4094956, dtype=float32), {})

# XTrace estimator; default samples uniformly on the n-sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD trace estimates
# (all eigenvalues of A are positive, so the PSD tag is valid here)
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)})
```
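Hutch++ spends the matrix-vector budget differently from Hutchinson. The following is a minimal from-scratch sketch of the Hutch++ idea, written for illustration and not taken from traceax: one third of the budget builds an orthonormal basis $Q$ for the dominant range of $\mathbf{A}$, the trace of $\mathbf{A}$ on that subspace is computed exactly as $\operatorname{trace}(Q^\top \mathbf{A} Q)$, and Hutchinson's estimator is applied only to the deflated remainder $(I - QQ^\top)\mathbf{A}(I - QQ^\top)$. The helper `hutchpp_trace`, its `matmat` argument (returning $\mathbf{A}V$ for a block of vectors $V$), and the equal three-way split of the budget are assumptions of this sketch and may differ from the internals of `tx.HutchPlusPlusEstimator`.

```python
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_trace(key, matmat, n, k):
    """Hypothetical helper sketching Hutch++ with about k matrix-vector products.

    `matmat(V)` must return A @ V for an (n, m) block V.
    """
    m = k // 3  # split the budget into three equal blocks of m products
    skey, gkey = rdm.split(key)
    S = rdm.rademacher(skey, (n, m)).astype(jnp.float32)
    G = rdm.rademacher(gkey, (n, m)).astype(jnp.float32)
    # block 1: orthonormal basis Q for the range of A @ S (the dominant subspace)
    Q, _ = jnp.linalg.qr(matmat(S))
    # block 2: exact trace of A restricted to span(Q), i.e. trace(Q^T A Q)
    low_rank = jnp.trace(Q.T @ matmat(Q))
    # project the remaining probes onto the orthogonal complement of span(Q)
    G_perp = G - Q @ (Q.T @ G)
    # block 3: Hutchinson estimate of trace((I - QQ^T) A (I - QQ^T))
    residual = jnp.trace(G_perp.T @ matmat(G_perp)) / m
    return low_rank + residual


# usage: the README's 0.7**i spectrum, built at a smaller size
n = 200
qkey, ekey = rdm.split(rdm.PRNGKey(0))
Q0, _ = jnp.linalg.qr(rdm.normal(qkey, (n, n)))
U = jnp.power(0.7, jnp.arange(n))
A = (Q0 * U) @ Q0.T
est = hutchpp_trace(ekey, lambda V: A @ V, n=n, k=24)
```

When $\mathbf{A}$ has only a few large eigenvalues, as with the $0.7^i$ spectrum above, most of the trace falls into the exact `low_rank` term and only a small remainder is left to the noisy Hutchinson step, which is why this style of estimator can beat plain Hutchinson at the same budget.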
Documentation is available here.
traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has some known issues on Macs with the M1 chip. To work around this, users should initialize conda using miniforge and then install traceax using pip in the desired environment.
Please report any bugs or feature requests in the Issue Tracker. If users have any questions or comments, please contact Linda Serafin (lserafin@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).
Feel free to use other software developed by the Mancuso Lab:
- SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
- MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
- SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
- twas_sim: Python software to simulate TWAS statistics.
- FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
- HAMSTA: Python software to estimate heritability explained by local ancestry data from admixture mapping summary statistics.
traceax is distributed under the terms of the Apache-2.0 license.
This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.