fastspecfit at risk of becoming slowspecfit #98
Comments
Here are the results from fitting 50 objects:
As expected, the I/O time becomes negligible (and this is with perlmutter in a degraded state) and the time is entirely dominated by the line-fitting (foremost) and the continuum-fitting (secondarily). Within the continuum-fitting, Zooming into the continuum-fitting: |
@moustakas kudos for paring fastspecfit down to the bone and eliminating all of the bottlenecks besides the actual fitting algorithm. For problems of modest dimension (Ndim <~ 50-100), the most robust and fastest fitter I've used that implements bounds (by a wide margin) is the L-BFGS algorithm, but this algorithm leans hard on accurate gradients because it uses the Hessian to condition the gradient descent. Higher-order derivatives can be slow to compute numerically, and it can get really tiresome to check their accuracy numerically, but they come with guaranteed machine precision "for free" if the likelihood/cost function is implemented in an autodiff library like JAX. Speedup factors can be an order of magnitude or more when deploying algorithms based on higher-order gradients on modern tensor-core GPUs.

The level of pain it would take to reimplement the internals of fastspecfit in JAX depends entirely on the nitty-gritty details of the algorithms you're using: sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones. The DSPS paper has some in-the-weeds pointers on JAX implementations for a traditional SPS approach to the problem. The DSPS library is open source, but it is really more a collection of kernels than a well-documented library, so I imagine it might not be very transparent to look at independently. I'd be happy to chat about this further in case you decide to go down this road.
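To make the bounded, gradient-aware fitting idea above concrete, here is a minimal sketch using SciPy's L-BFGS-B with a hand-derived analytic gradient on a toy single-Gaussian fit. In the JAX version being proposed, the `grad` function below would instead come from `jax.grad(cost)` automatically. All names and the model here are illustrative, not fastspecfit's actual code.

```python
import numpy as np
from scipy.optimize import minimize

# Toy bounded least-squares problem: fit the amplitude and width of a
# single Gaussian (illustrative stand-in for the real cost function).
x = np.linspace(-5.0, 5.0, 201)
true_amp, true_sigma = 2.0, 1.3
data = true_amp * np.exp(-0.5 * (x / true_sigma) ** 2)

def cost(theta):
    amp, sigma = theta
    model = amp * np.exp(-0.5 * (x / sigma) ** 2)
    return 0.5 * np.sum((model - data) ** 2)

def grad(theta):
    # Analytic gradient; with JAX this would simply be jax.grad(cost).
    amp, sigma = theta
    g = np.exp(-0.5 * (x / sigma) ** 2)
    resid = amp * g - data
    damp = np.sum(resid * g)
    dsigma = np.sum(resid * amp * g * x**2 / sigma**3)
    return np.array([damp, dsigma])

# L-BFGS-B supports box constraints, e.g. non-negative amplitudes.
res = minimize(cost, x0=[1.0, 1.0], jac=grad, method="L-BFGS-B",
               bounds=[(0.0, 10.0), (0.1, 5.0)])
```

Because the gradient is exact rather than finite-differenced, each iteration costs one model evaluation instead of one per parameter, which is where the autodiff speedup comes from.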
Thanks for the comments @aphearin. I'll take a look at your links.
At @dmargala's urging, I tried to write the fitting "guts" using as much pure numpy as I could. For example, here's the objective function that gets passed to the optimizer; one notable (but not dominant) bottleneck is the construction of the emission-line model (a sum of Gaussians) for the set of parameters being optimized (anywhere from 10 to 50 parameters). For the optimization itself, I was thinking of using one of the constrained optimization algorithms in https://jaxopt.github.io/stable/constrained.html, although I have no idea how to get JAX (and jaxlib) running at NERSC (whether in a Docker/Shifter container or not; see, e.g., jax-ml/jax#6340).
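The sum-of-Gaussians emission-line model mentioned above can be built without a Python loop via numpy broadcasting; a hypothetical sketch (not fastspecfit's actual function, and the line names are just examples):

```python
import numpy as np

def gaussian_line_model(wave, amps, centers, sigmas):
    """Sum of Gaussian emission lines evaluated on the `wave` grid.

    Broadcasting builds an (nline, npix) array in one shot instead of
    looping over lines in Python; illustrative sketch only.
    """
    amps = np.atleast_1d(amps)[:, None]        # (nline, 1)
    centers = np.atleast_1d(centers)[:, None]  # (nline, 1)
    sigmas = np.atleast_1d(sigmas)[:, None]    # (nline, 1)
    profiles = amps * np.exp(-0.5 * ((wave[None, :] - centers) / sigmas) ** 2)
    return profiles.sum(axis=0)                # (npix,)

wave = np.linspace(4800.0, 5100.0, 1000)
model = gaussian_line_model(wave,
                            amps=[1.0, 3.0],
                            centers=[4861.3, 5006.8],  # e.g. Hbeta, [OIII]
                            sigmas=[2.0, 2.0])
```

A vectorized kernel like this also translates almost line-for-line into JAX (`jnp` in place of `np`), which matters if the gradient of the objective is to come from autodiff.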
OK, this actually looks pretty tractable to me. The usual things that need to be rewritten are control flow within for loops, which takes a little fiddling but is usually not too bad. One blocker to be aware of is while loops, which are a no-go for taking gradients with JAX (while loops are actually supported by JAX, but I think reverse-mode differentiation through them is unsupported, and I don't know whether support is coming anytime soon). I think rewriting the _trapz_rebin kernel in JAX might be a good first focal point; it doesn't look too bad at first glance.
I'm not seeing anything in the _trapz_rebin kernel that looks like a blocker.
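For readers unfamiliar with the kernel being discussed, the idea behind a flux-conserving trapezoidal rebin can be sketched in pure numpy as follows. This is not Redrock's actual `trapz_rebin` implementation, just an illustration of the computation; the Python loop over output bins is exactly the kind of control flow that numba (or a JAX rewrite) would replace.

```python
import numpy as np

def trapz_rebin_sketch(x, y, edges):
    """Flux-conserving rebin: the average of the piecewise-linear
    spectrum (x, y) over each output bin defined by `edges`.
    Illustrative sketch; Redrock's trapz_rebin is the reference."""
    # Evaluate the interpolated spectrum on the union of input points
    # and bin edges, so each trapezoid integral is exact.
    grid = np.union1d(x, edges)
    grid = grid[(grid >= edges[0]) & (grid <= edges[-1])]
    yg = np.interp(grid, x, y)
    out = np.empty(len(edges) - 1)
    for i in range(len(edges) - 1):
        sel = (grid >= edges[i]) & (grid <= edges[i + 1])
        out[i] = np.trapz(yg[sel], grid[sel]) / (edges[i + 1] - edges[i])
    return out

wave = np.linspace(3600.0, 9800.0, 6001)        # fine input grid
flux = np.ones_like(wave)                       # flat spectrum
edges = np.arange(3600.0, 9800.1, 50.0)         # coarser output bins
rebinned = trapz_rebin_sketch(wave, flux, edges)  # stays flat
```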
@moustakas in case its helpful I recently wrote a GPU-optimized version of |
Hello, reading your comments about JAX-accelerated spectral synthesis: do you know of ExoJAX, and might it be suited to your application?
The code refactoring in #92, #95, and #96 added much more robust and extensive fitting capabilities (even after removing the dependence on the very slow astropy.modeling classes), at the expense of speed. For example, for the non-linear least-squares fitting used to model the emission lines, I switched from the Levenberg-Marquardt (`lm`) algorithm in scipy.optimize.least_squares to the Trust Region Reflective (`trf`) algorithm, which incorporates bounds into the minimization procedure and is significantly more robust, but slower! The bottleneck in `trf` is the numerical differentiation step (see, e.g., https://stackoverflow.com/questions/68507176/faster-scipy-optimizations), so I think the code is a great candidate to be ported to GPUs, where algorithmic differentiation using, e.g., JAX can be factors of many faster.

In addition, for the stellar continuum fitting (including inferring the velocity dispersion), the main algorithm is the non-negative least-squares fitting provided by scipy.optimize.nnls. Despite its Fortran bindings, `nnls` is still a bottleneck; the other slow pieces are the resampling of the templates at the redshift of each object and the convolution with the resolution matrix (here, I'm using Redrock's trapz_rebin algorithm, which already uses `numba`/`jit` for speed, but still takes non-negligible time).

@dmargala @marcelo-alvarez @craigwarner-ufastro @sbailey and others: I'd be grateful for any thoughts or insight on how to proceed.
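The `lm`-to-`trf` trade-off described above can be seen in a minimal example: `lm` (MINPACK) is fast but rejects bounds, while `trf` accepts box constraints at the cost of extra numerical-differentiation work per iteration. The toy residual function here is illustrative, not fastspecfit's.

```python
import numpy as np
from scipy.optimize import least_squares

# Toy problem: fit one Gaussian's amplitude and width (illustrative only).
x = np.linspace(-4.0, 4.0, 101)
data = 2.0 * np.exp(-0.5 * (x / 1.5) ** 2)

def resid(theta):
    amp, sigma = theta
    return amp * np.exp(-0.5 * (x / sigma) ** 2) - data

# 'lm' (MINPACK Levenberg-Marquardt): fast, but no bounds support.
fit_lm = least_squares(resid, x0=[1.0, 1.0], method="lm")

# 'trf' (Trust Region Reflective): handles box constraints, e.g. forcing
# the amplitude non-negative. Without an analytic (or autodiff) Jacobian
# passed via jac=, each iteration pays for finite-difference derivatives,
# which is the bottleneck noted above.
fit_trf = least_squares(resid, x0=[1.0, 1.0], method="trf",
                        bounds=([0.0, 0.1], [10.0, 5.0]))
```

Supplying `jac=` with a JAX-computed Jacobian would remove the finite-difference step while keeping `trf`'s bounds handling.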
In the meantime, here are some profiling results:
Logging into perlmutter:
I get the following profile. (Note: I wasn't sure how to launch `snakeviz` on perlmutter, so I copied the `fastspec.prof` file to my laptop. Also, I'm ignoring the I/O for the moment because those steps are slow for a single object but should be a small fraction of the total time when fitting a full healpixel.)
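As a text-mode alternative to launching `snakeviz`, a `.prof` file can be summarized directly with the stdlib `pstats` module on the login node. A minimal sketch; the workload and the `example.prof` filename are placeholders standing in for the actual `fastspec` run and its `fastspec.prof` output.

```python
import cProfile
import pstats

def workload():
    """Stand-in for the actual fitting call; purely illustrative."""
    return sum(i * i for i in range(200_000))

# Collect a profile and dump it to disk, analogous to how fastspec
# writes fastspec.prof.
profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()
profiler.dump_stats("example.prof")

# Summarize the .prof file: a quick text-mode look at the hot spots
# when snakeviz is unavailable on the compute node.
stats = pstats.Stats("example.prof")
stats.sort_stats("cumulative").print_stats(5)  # top 5 by cumulative time
```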