I tried to implement this myself with Givens rotations, but it needs N*M sequential loop iterations (one rotation each), so it is very slow. That makes it useless, because the main goal of this function is to compute the QR factorization faster than simply calling jax.scipy.linalg.qr on the full matrix. I think this would be useful for anyone using JAX for trust-region Newton optimization. I also considered Householder reflections, but I couldn't get them to work under jax.jit. For reference, SciPy implements this here, calling some LAPACK routines in a for loop. I hope this can be implemented in the same way as some of the other matrix decomposition routines in JAX.
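On the Householder point: the usual obstacle under jax.jit is that each step works on a shrinking sub-block, and slicing like `x[k:]` with a traced loop index is not allowed because the shape would be dynamic. A common workaround is to keep every array at full size and mask the rows above the pivot with `jnp.where`, so that each reflection becomes a fixed-shape rank-1 update. The sketch below (`householder_qr` is a hypothetical helper, for real floating-point input) shows the pattern; it computes an ordinary full QR and will not beat the library QR behind `jnp.linalg.qr`, but the same masking trick carries over to structured variants.

```python
import jax
import jax.numpy as jnp


@jax.jit
def householder_qr(A):
    """Full QR of a real float matrix A (m x n) via Householder reflections.

    All arrays keep static shapes: entries above the current pivot are masked
    with jnp.where instead of being sliced with a traced index.
    """
    m, n = A.shape
    rows = jnp.arange(m)

    def step(k, QR):
        Q, R = QR
        # Column k with the entries above the diagonal masked to zero.
        x = jnp.where(rows >= k, R[:, k], 0.0)
        norm_x = jnp.linalg.norm(x)
        # Reflect x onto -sign(x[k]) * ||x|| * e_k; this sign choice avoids cancellation.
        sign = jnp.where(x[k] >= 0, 1.0, -1.0)
        v = x.at[k].add(sign * norm_x)
        vv = v @ v
        # beta = 2 / (v^T v), or 0 (identity) when the masked column is already zero.
        beta = jnp.where(vv > 0, 2.0 / jnp.where(vv > 0, vv, 1.0), 0.0)
        R = R - beta * jnp.outer(v, v @ R)  # R <- H @ R
        Q = Q - beta * jnp.outer(Q @ v, v)  # Q <- Q @ H (H is symmetric)
        return Q, R

    Q0 = jnp.eye(m, dtype=A.dtype)
    Q, R = jax.lax.fori_loop(0, min(m, n), step, (Q0, A))
    return Q, jnp.triu(R)  # drop round-off below the diagonal
```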
My naive attempt, written purely in Python with JAX, looks like this:
import jax
import jax.numpy as jnp
from jax.lax import fori_loop, rsqrt

def _givens_jax(a, b):
    """2x2 Givens rotation G with G @ [a, b] = [h, 0] (real a, b)."""
    b_zero = jnp.abs(b) == 0
    a_lt_b = jnp.abs(a) < jnp.abs(b)
    # t = -(smaller-magnitude entry) / (larger-magnitude entry), so |t| <= 1.
    t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
    r = rsqrt(1 + jnp.abs(t) ** 2).astype(t.dtype)
    cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
    sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
    return jnp.array([[cs, -sn], [sn, cs]])

@jax.jit
def update_qr_jax(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    # Start from [A; w] = blockdiag(q, I) @ [r; w].
    Q = jnp.eye(m + n).at[:m, :m].set(q)
    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m + j - i  # zero R[i, j], chasing the entry up from row m + j
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        rows = jnp.array([i - 1, i])
        R = R.at[rows].set(G2 @ R[rows])
        Q = Q.at[:, rows].set(Q[:, rows] @ G2.T)  # keeps Q @ R unchanged
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        _, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)  # flush round-off to zero
    return Q, R
Note: I also tried economic (reduced) mode QR to reduce the matrix size, but it is still slow.
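One way to avoid the per-element loop entirely is a different technique from SciPy's rotation-based update: if the reduced factorization A = q @ r is already available, then [A; w] = blockdiag(q, I) @ [r; w], so only the small stacked matrix [r; w] needs a fresh QR. That is a single `jnp.linalg.qr` call whose size does not depend on the number of rows of A; forming the new Q still costs one m x n by n x n product. The sketch below assumes A is tall (m >= n) and that w has n columns; `qr_stack_below` is a hypothetical name. I would expect it to help mainly when m is much larger than n, but that is an assumption, not a measured result.

```python
import jax
import jax.numpy as jnp


@jax.jit
def qr_stack_below(q, r, w):
    """Reduced QR of [A; w], given the reduced factorization A = q @ r.

    Relies on [A; w] = blockdiag(q, I) @ [r; w]: only the small stacked
    matrix [r; w] is re-factorized.
    """
    p = r.shape[0]  # rows of r (n when A is tall)
    # One library QR call on the small (p + k) x n matrix; no Python-level loop.
    P, S = jnp.linalg.qr(jnp.vstack([r, w]), mode="reduced")
    # blockdiag(q, I) @ P, computed blockwise instead of forming blockdiag(q, I).
    Q_new = jnp.vstack([q @ P[:p], P[p:]])
    return Q_new, S


# Example with a diagonal w, as in the post (shapes are illustrative).
key_a, key_d = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (200, 5))
W = jnp.diag(jax.random.uniform(key_d, (5,)))
q, r = jnp.linalg.qr(A)  # reduced mode is the default
Q_new, R_new = qr_stack_below(q, r, W)
```

If only R is needed (plus Q^T applied to a right-hand side), the `q @ P[:p]` product can be skipped, which removes the only term that scales with m.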
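For context, my application needs the QR factorization of the stacked matrix [A; W], where A is m x n and W is an n x n diagonal matrix, and I already have the factorization A = QR.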
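When only the new triangular factor is needed, as is common for the damped least-squares step in Levenberg-Marquardt or trust-region methods, the diagonal rows can be folded into R one at a time while touching only the n x n triangle. The sketch below follows the idea of MINPACK's `qrsolv` routine (eliminating a diagonal matrix with Givens rotations against R), but it is an illustrative transcription, not a port; `fold_diagonal_into_r` and `_givens` are hypothetical names, and `d` is assumed to hold the diagonal of W as a vector. It still runs n^2 sequential loop iterations (n(n+1)/2 of them non-trivial), but it never forms Q, so its cost does not depend on m.

```python
import jax
import jax.numpy as jnp


def _givens(a, b):
    """(c, s) with [[c, -s], [s, c]] @ [a, b] = [h, 0]; exactly the identity when b == 0."""
    b_zero = b == 0
    h = jnp.where(b_zero, 1.0, jnp.hypot(a, b))  # avoids 0/0 when a == b == 0
    c = jnp.where(b_zero, 1.0, a / h)
    s = jnp.where(b_zero, 0.0, -b / h)
    return c, s


@jax.jit
def fold_diagonal_into_r(R, d):
    """Upper-triangular S with S.T @ S == R.T @ R + diag(d)**2.

    Up to row signs, S is the R factor of [A; diag(d)] when A = Q @ R.
    """
    n = R.shape[0]

    def fold_row(j, R):
        extra = jnp.zeros(n, R.dtype).at[j].set(d[j])  # row j of diag(d)

        def rotate(k, carry):
            R, extra = carry
            # Zero extra[k] against the pivot R[k, k]; for k < j, extra[k] == 0 (identity).
            c, s = _givens(R[k, k], extra[k])
            row_k = R[k]
            R = R.at[k].set(c * row_k - s * extra)
            # extra[k] is zero by construction; set it exactly to avoid round-off fill-in.
            extra = (s * row_k + c * extra).at[k].set(0.0)
            return R, extra

        R, _ = jax.lax.fori_loop(0, n, rotate, (R, extra))
        return R

    return jax.lax.fori_loop(0, n, fold_row, R)
```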