Skip to content

Commit

Permalink
Added tests dP, dA, dl, du.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Jul 5, 2024
1 parent dc8c56a commit 6568e68
Showing 1 changed file with 117 additions and 2 deletions.
119 changes: 117 additions & 2 deletions src/osqp/tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type):
grads = [x.grad.data.squeeze(0).cpu().numpy() for x in [P_torch, q_torch, A_torch, l_torch, u_torch]]
return grads


def test_dl_dp(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_dq(algebra, solver_type, atol, rtol, decimal_tol):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand Down Expand Up @@ -106,3 +105,119 @@ def f(q):
print('dq_fd: ', np.round(dq_fd, decimals=4))
print('dq: ', np.round(dq, decimals=4))
npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)

def test_dl_dP(algebra, solver_type, atol, rtol, decimal_tol):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
if not model.has_capability('OSQP_CAPABILITY_DERIVATIVES'):
pytest.skip('No derivatives capability')

[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n,
m=m,
P_scale=100.0,
A_scale=100.0,
algebra=algebra,
solver_type=solver_type,
)

def f(P):
model.setup(P, q, A, l, u, solver_type=solver_type, verbose=False)
res = model.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

dP_fd = approx_fprime(P, f)
if verbose:
print('dP_fd: ', np.round(dP_fd, decimals=4))
print('dP: ', np.round(dP, decimals=4))
npt.assert_allclose(dP_fd, dP, rtol=RTOL, atol=ATOL)

def test_dl_dA(algebra, solver_type, atol, rtol, decimal_tol):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
if not model.has_capability('OSQP_CAPABILITY_DERIVATIVES'):
pytest.skip('No derivatives capability')

[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n,
m=m,
P_scale=100.0,
A_scale=100.0,
algebra=algebra,
solver_type=solver_type,
)

def f(A):
model.setup(P, q, A, l, u, solver_type=solver_type, verbose=False)
res = model.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

dA_fd = approx_fprime(A, f)
if verbose:
print('dA_fd: ', np.round(dA_fd, decimals=4))
print('dA: ', np.round(dA, decimals=4))
npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL)

def test_dl_dl(algebra, solver_type, atol, rtol, decimal_tol):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
if not model.has_capability('OSQP_CAPABILITY_DERIVATIVES'):
pytest.skip('No derivatives capability')

[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n,
m=m,
P_scale=100.0,
A_scale=100.0,
algebra=algebra,
solver_type=solver_type,
)

def f(l):
model.setup(P, q, A, l, u, solver_type=solver_type, verbose=False)
res = model.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

dl_fd = approx_fprime(l, f)
if verbose:
print('dl_fd: ', np.round(dl_fd, decimals=4))
print('dl: ', np.round(dl, decimals=4))
npt.assert_allclose(dl_fd, dl, rtol=RTOL, atol=ATOL)

def test_dl_du(algebra, solver_type, atol, rtol, decimal_tol):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
if not model.has_capability('OSQP_CAPABILITY_DERIVATIVES'):
pytest.skip('No derivatives capability')

[P, q, A, l, u, true_x], [dP, dq, dA, dl, du] = get_grads(
n=n,
m=m,
P_scale=100.0,
A_scale=100.0,
algebra=algebra,
solver_type=solver_type,
)

def f(u):
model.setup(P, q, A, l, u, solver_type=solver_type, verbose=False)
res = model.solve()
x_hat = res.x

return 0.5 * np.sum(np.square(x_hat - true_x))

du_fd = approx_fprime(u, f)
if verbose:
print('du_fd: ', np.round(du_fd, decimals=4))
print('du: ', np.round(du, decimals=4))
npt.assert_allclose(du_fd, du, rtol=RTOL, atol=ATOL)

0 comments on commit 6568e68

Please sign in to comment.