
fastspecfit at risk of becoming slowspecfit #98

Open
moustakas opened this issue Jan 26, 2023 · 8 comments

Comments

@moustakas (Member)

The code refactoring in #92, #95, and #96 added much more robust and extensive fitting capabilities (even after removing the dependence on the very slow astropy.modeling classes), but at the expense of speed. For example, for the non-linear least-squares fitting used to model the emission lines, I switched from the Levenberg-Marquardt (lm) algorithm in scipy.optimize.least_squares to the Trust Region Reflective (trf) algorithm, which incorporates bounds into the minimization procedure and is significantly more robust---but slower!
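For illustration, a minimal sketch of the lm-vs-trf trade-off on a toy single-Gaussian fit (the data and parameter names here are made up for the example, not fastspecfit's actual objective):

```python
import numpy as np
from scipy.optimize import least_squares

# Toy data: a single noisy Gaussian, as a stand-in for an emission line.
rng = np.random.default_rng(42)
x = np.linspace(-5.0, 5.0, 200)
true = np.array([2.0, 0.5, 1.2])  # amplitude, mean, sigma
y = true[0] * np.exp(-0.5 * ((x - true[1]) / true[2])**2)
y += 0.05 * rng.standard_normal(x.size)

def resid(p):
    amp, mu, sigma = p
    return amp * np.exp(-0.5 * ((x - mu) / sigma)**2) - y

p0 = np.array([1.0, 0.0, 1.0])

# 'lm' is fast but cannot handle bounds; 'trf' supports bounds such as
# a non-negative amplitude and a positive, capped line width.
fit_lm = least_squares(resid, p0, method='lm')
fit_trf = least_squares(resid, p0, method='trf',
                        bounds=([0.0, -5.0, 1e-3], [np.inf, 5.0, 10.0]))
# fit_trf.x ~ [2.0, 0.5, 1.2]
```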

The bottleneck in trf is the numerical-differentiation step (see, e.g., https://stackoverflow.com/questions/68507176/faster-scipy-optimizations), so I think the code is a great candidate to be ported to GPUs, where algorithmic differentiation using, e.g., JAX can be many times faster.
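A hedged sketch of what that could look like: jax.jacfwd supplies an exact Jacobian to scipy.optimize.least_squares in place of finite differencing. The residual function below is a toy stand-in, not the fastspecfit objective:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import least_squares

jax.config.update("jax_enable_x64", True)

# Noiseless toy data: a single Gaussian "emission line".
x = jnp.linspace(-5.0, 5.0, 200)
y = 2.0 * jnp.exp(-0.5 * ((x - 0.5) / 1.2)**2)

def resid(p):
    amp, mu, sigma = p
    return amp * jnp.exp(-0.5 * ((x - mu) / sigma)**2) - y

# jacfwd gives the exact Jacobian; jit compiles both functions once.
resid_jit = jax.jit(resid)
jac_jit = jax.jit(jax.jacfwd(resid))

fit = least_squares(lambda p: np.asarray(resid_jit(p)),
                    x0=np.array([1.0, 0.0, 1.0]),
                    jac=lambda p: np.asarray(jac_jit(p)),
                    method='trf',
                    bounds=([0.0, -5.0, 1e-3], [np.inf, 5.0, 10.0]))
```

With an exact Jacobian, trf skips the repeated residual evaluations that finite differencing requires, which is the bottleneck flagged above.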

In addition, for the stellar-continuum fitting (including inferring the velocity dispersion), the main algorithm is the non-negative least-squares fitting provided by scipy.optimize.nnls. Despite its Fortran bindings, nnls is still a bottleneck; the other slow pieces are the resampling of the templates at the redshift of each object and the convolution with the resolution matrix (here, I'm using Redrock's trapz_rebin algorithm, which already uses Numba's JIT compilation for speed---but still takes non-negligible time).
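For reference, the nnls call pattern being discussed, on a toy two-template problem (the template shapes here are invented for illustration):

```python
import numpy as np
from scipy.optimize import nnls

# Toy continuum fit: data modeled as a non-negative combination of templates.
wave = np.linspace(3600.0, 9800.0, 500)
templates = np.column_stack([
    np.exp(-wave / 5000.0),      # hypothetical "old population" shape
    (wave / 5000.0)**-1.5,       # hypothetical "young population" shape
])
coeff_true = np.array([3.0, 1.5])
flux = templates @ coeff_true

# nnls solves min ||A x - b||_2 subject to x >= 0.
coeff, rnorm = nnls(templates, flux)
```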

@dmargala @marcelo-alvarez @craigwarner-ufastro @sbailey and others---I'd be grateful for any thoughts or insights on how to proceed.

In the meantime, here are some profiling results:

Logging into Perlmutter:

source /global/cfs/cdirs/desi/software/desi_environment.sh 23.1
module load fastspecfit/2.0.0

python -m cProfile -o fastspec.prof /global/common/software/desi/perlmutter/desiconda/20230111-2.1.0/code/fastspecfit/2.0.0/bin/fastspec \
  $DESI_ROOT/spectro/redux/fuji/tiles/cumulative/80613/20210324/redrock-4-80613-thru20210324.fits \
  -o fastspec.fits --targetids 39633345008634465

I get the following profile. (Note: I wasn't sure how to launch snakeviz on Perlmutter, so I copied the fastspec.prof file to my laptop. Also, I'm ignoring the I/O for the moment because those steps are slow for a single object but should be a small fraction of the total time when fitting a full healpixel.)

[Screenshot: snakeviz profile of the single-object fit, 2023-01-26]

[Screenshot: snakeviz profile (detail), 2023-01-26]

@moustakas (Member, Author)

Here are the results from fitting 50 objects:

python -m cProfile -o fastspec-50.prof /global/common/software/desi/perlmutter/desiconda/20230111-2.1.0/code/fastspecfit/2.0.0/bin/fastspec   \
  $DESI_ROOT/spectro/redux/fuji/tiles/cumulative/80613/20210324/redrock-4-80613-thru20210324.fits \
  --ntargets 50 -o fastspec-50.fits

As expected, the I/O time becomes negligible (and this is with Perlmutter in a degraded state), and the total time is dominated by the line-fitting (foremost) and the continuum-fitting (secondarily). Within the continuum-fitting, nnls is actually pretty negligible compared to the trapezoidal rebinning (called by smooth_and_resample), the Gaussian broadening (part of the velocity-dispersion fitting), and the filter convolutions (used to synthesize photometry).

[Screenshot: snakeviz profile of the 50-object fit, 2023-01-27]

Zooming into the continuum-fitting:

[Screenshot: snakeviz profile zoomed into the continuum-fitting, 2023-01-27]

@aphearin

@moustakas kudos for paring fastspecfit down to the bone and eliminating all of the bottlenecks besides the actual fitting algorithm. For problems of modest dimension (Ndim <~ 50-100), the most robust and fastest fitter I've used that implements bounds (kind of by far) is the L-BFGS-B algorithm, but it leans hard on accurate gradients because it uses an approximate Hessian to condition the gradient descent. Higher-order derivatives can be slow to compute numerically, and it can get really tiresome to check their accuracy, but they come with guaranteed machine precision "for free" if the likelihood/cost function is implemented in an autodiff library like JAX. Speedup factors can be an order of magnitude or more when deploying gradient-based algorithms on modern tensor-core GPUs.
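To make the suggestion concrete, here is a minimal bound-constrained L-BFGS-B sketch with an analytic gradient, using a toy cost function rather than anything from fastspecfit; with JAX, `jac=jax.grad(cost)` would supply the same gradient automatically:

```python
import numpy as np
from scipy.optimize import minimize

# Toy bound-constrained problem: the Rosenbrock function on [0, 2] x [0, 2].
def cost(p):
    return (1.0 - p[0])**2 + 100.0 * (p[1] - p[0]**2)**2

def grad(p):
    # Analytic gradient; L-BFGS-B leans hard on its accuracy.
    return np.array([
        -2.0 * (1.0 - p[0]) - 400.0 * p[0] * (p[1] - p[0]**2),
        200.0 * (p[1] - p[0]**2),
    ])

res = minimize(cost, x0=np.array([0.5, 0.5]), jac=grad,
               method='L-BFGS-B', bounds=[(0.0, 2.0), (0.0, 2.0)])
# res.x converges to the global minimum at (1, 1)
```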

The level of pain it would require to reimplement the internals of fastspecfit in JAX depends entirely on the nitty-gritty details of the algorithms you're using: sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones. The DSPS paper has some in-the-weeds pointers on JAX implementations for the case of a traditional SPS approach to the problem. The DSPS library is open-source, but it is really more a collection of kernels than a well-documented library at this point, so I imagine it might not be very transparent to look at independently. I'd be happy to chat about this further in case you decide to go down this road.

@moustakas (Member, Author)

Thanks for the comments @aphearin. I'll take a look at your links.

The level of pain it would require to reimplement the internals of fastspecfit into JAX depends entirely on the nitty gritty details of the algorithms you're using - sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones.

At @dmargala's urging, I tried to write the fitting "guts" using as much pure NumPy as I could. For example, here's the objective function that is passed to scipy.optimize.least_squares:
https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L74-L113

with one notable (but not dominant) bottleneck being the construction of the emission-line model (a sum of Gaussians) for the set of parameters being optimized (typically 10-50 parameters):
https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L28-L72
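As a rough illustration (not the actual fastspecfit kernel, which also handles pixel integration and the resolution matrix), a broadcast-based sum-of-Gaussians model might look like:

```python
import numpy as np

def emline_model(wave, amps, means, sigmas):
    """Hypothetical simplified sum-of-Gaussians emission-line model.

    wave is the wavelength grid; amps, means, and sigmas are per-line
    parameter arrays (or scalars for a single line).
    """
    # Broadcast (nline, 1) against (npix,) -> (nline, npix), then sum lines.
    amps = np.atleast_1d(amps)[:, None]
    means = np.atleast_1d(means)[:, None]
    sigmas = np.atleast_1d(sigmas)[:, None]
    return np.sum(amps * np.exp(-0.5 * ((wave - means) / sigmas)**2), axis=0)
```

Because the whole model is element-wise NumPy plus one reduction, a kernel of this shape translates almost line-for-line into jax.numpy.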

For the optimization, I was thinking of using one of the constrained optimization algorithms in https://jaxopt.github.io/stable/constrained.html, although I have no idea how to get JAX (and jaxlib) running at NERSC (whether in a Docker/Shifter container or not; see, e.g., jax-ml/jax#6340).
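As a sketch of what bound-constrained optimization in JAX can look like, without committing to a particular jaxopt solver, here is a minimal hand-rolled projected-gradient loop on a toy quadratic cost (jaxopt's own solvers would be the production route):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def cost(p):
    # Toy quadratic with unconstrained minimum at (-1, 3),
    # deliberately outside the feasible box below.
    return (p[0] + 1.0)**2 + (p[1] - 3.0)**2

grad = jax.jit(jax.grad(cost))
lower = jnp.array([0.0, 0.0])
upper = jnp.array([2.0, 2.0])

p = jnp.array([1.0, 1.0])
for _ in range(200):
    # Gradient step, then project back onto the feasible box.
    p = jnp.clip(p - 0.1 * grad(p), lower, upper)
# Both bounds end up active: p -> (0.0, 2.0)
```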

@aphearin

OK, this looks pretty tractable to me, actually. The usual things that need to be rewritten are control flow within for loops, which takes a little fiddling but is usually not too bad. One blocker to be aware of is while loops: while loops are actually supported by JAX, but I think reverse-mode differentiation through them is unsupported, and I don't know whether support is coming anytime soon. I think rewriting the _trapz_rebin kernel in JAX might be a good first focal point; it doesn't look too bad at first glance.
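One possible starting point, offered as a sketch rather than a drop-in replacement for Redrock's implementation: a loop-free, differentiable trapezoidal rebin built from a cumulative integral plus jnp.interp. Note that linearly interpolating the cumulative integral is only exact when the output edges coincide with input grid points; between grid points it is an approximation:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.jit
def trapz_rebin_jax(x, y, edges):
    """Sketch of a trapezoidal rebin with no Python-level loops.

    Integrates the sampled function (x, y) via its cumulative trapezoidal
    integral, then differences that integral at the output bin edges and
    divides by the bin widths.  Assumes edges lie within the range of x.
    """
    # Cumulative trapezoidal integral of y on the native grid.
    cum = jnp.concatenate([
        jnp.zeros(1),
        jnp.cumsum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x)),
    ])
    # Evaluate the cumulative integral at the output bin edges.
    cum_at_edges = jnp.interp(edges, x, cum)
    # Mean value of y in each output bin.
    return jnp.diff(cum_at_edges) / jnp.diff(edges)
```

Because everything here is jnp primitives with no data-dependent control flow, the kernel is jit-compilable and both forward- and reverse-mode differentiable.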

@moustakas (Member, Author)

trapz_rebin is also one of the bottlenecks in Redrock, actually (from which I shamelessly stole this code, with permission from @sbailey), but I'm not sure whether someone is already rewriting it away from numba.jit into a GPU-optimized form.

@aphearin

I'm not seeing anything in the _trapz_rebin kernel that looks like a blocker.

@craigwarner-ufastro

@moustakas in case it's helpful, I recently wrote a GPU-optimized version of trapz_rebin that has now been merged into the main branch of Redrock. If you are rebinning to many redshifts, it is much faster.

@erwanp commented Aug 19, 2023

Hello, reading your comments about JAX-accelerated spectral synthesis: do you know of Exojax, and is it suited for your application?
