Fix JAX 0.4.28 regression in SciPy logsumexp scale propagation.
#335
| Job | Run time |
|---|---|
| | 31s |
| | 22s |
| | 53s |
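The title does not describe the failure itself. The sketch below only illustrates the quantity involved. A SciPy-style logsumexp accepts an optional scale `b` and computes log(Σ b_i · exp(a_i)). When some `b_i` are negative, it also returns the sign of the sum. The helper name `scaled_logsumexp` is hypothetical, and the code is plain Python rather than JAX. It shifts only by the largest exponent whose scale is nonzero. That is one reasonable way to keep a zero-weighted huge entry from underflowing the rest. It is not necessarily how JAX or SciPy implement it.

```python
import math

def scaled_logsumexp(a, b):
    """Return (log|sum(b_i * exp(a_i))|, sign), computed stably.

    Hypothetical illustration of logsumexp with a scale factor b
    (SciPy-style convention); not the JAX or SciPy implementation.
    Inputs are assumed to be finite or -inf.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same length")
    # Terms with zero scale contribute nothing, so ignore them entirely;
    # otherwise a huge a_i with b_i == 0 could dominate the shift below.
    terms = [(ai, bi) for ai, bi in zip(a, b) if bi != 0]
    if not terms:
        return -math.inf, 0.0
    # Shift by the largest remaining exponent so exp() cannot overflow.
    m = max(ai for ai, _ in terms)
    if m == -math.inf:
        return -math.inf, 0.0
    # fsum reduces cancellation error when scales have mixed signs.
    total = math.fsum(bi * math.exp(ai - m) for ai, bi in terms)
    if total == 0.0:
        return -math.inf, 0.0
    sign = 1.0 if total > 0 else -1.0
    return m + math.log(abs(total)), sign
```

For example, `scaled_logsumexp([1000.0, 1000.0], [2.0, -1.0])` evaluates 2·e^1000 − e^1000 without overflow and yields `(1000.0, 1.0)`. A naive `math.log(sum(...))` would overflow in `math.exp(1000.0)`.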