Skip to content

Commit

Permalink
blacked
Browse files Browse the repository at this point in the history
  • Loading branch information
EveCharbie committed Jul 23, 2023
1 parent 67e29ff commit 691a73a
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 51 deletions.
3 changes: 1 addition & 2 deletions bioptim/dynamics/configure_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,6 @@ def configure_stochastic_cov_implicit(ocp, nlp, n_noised_states: int):
skip_plot=True,
)


@staticmethod
def configure_stochastic_cholesky_cov(ocp, nlp, n_noised_states: int):
"""
Expand All @@ -1468,7 +1467,7 @@ def configure_stochastic_cholesky_cov(ocp, nlp, n_noised_states: int):

name_cov = []
for nb_1, name_1 in enumerate([f"X_{i}" for i in range(n_noised_states)]):
for name_2 in [f"X_{i}" for i in range(nb_1+1)]:
for name_2 in [f"X_{i}" for i in range(nb_1 + 1)]:
name_cov += [name_1 + "_&_" + name_2]
nlp.variable_mappings[name] = BiMapping(list(range(len(name_cov))), list(range(len(name_cov))))
ConfigureProblem.configure_new_variable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ def reach_target_consistantly(controllers: list[PenaltyController]) -> cas.MX:
val = fun(
controllers[-1].states["q"].cx_start,
controllers[-1].states["qdot"].cx_start,
(controllers[-1].stochastic_variables["cholesky_cov"].cx_start if "cholesky_cov" in controllers[-1].stochastic_variables.keys() else controllers[-1].stochastic_variables["cov"].cx_start),
(
controllers[-1].stochastic_variables["cholesky_cov"].cx_start
if "cholesky_cov" in controllers[-1].stochastic_variables.keys()
else controllers[-1].stochastic_variables["cov"].cx_start
),
)
# Since the stochastic variables are defined with ns+1, the cx_start actually refers to the last node (when using node=Node.END)

Expand Down Expand Up @@ -577,7 +581,7 @@ def prepare_socp(
else:
n_cholesky_cov = 0
for i in range(n_states):
for j in range(i+1):
for j in range(i + 1):
n_cholesky_cov += 1
n_stochastic += n_cholesky_cov # + cholesky_cov(10)
stochastic_init = np.zeros((n_stochastic, n_shooting + 1))
Expand Down Expand Up @@ -640,7 +644,7 @@ def prepare_socp(
cov_init = cas.DM_eye(n_states) * np.array([1e-4, 1e-4, 1e-7, 1e-7])
idx = 0
for i in range(n_states):
for j in range(i+1):
for j in range(i + 1):
stochastic_init[idx, 0] = cov_init[i, j]
s_init.add(
"cholesky_cov",
Expand Down Expand Up @@ -723,7 +727,7 @@ def main():
sensory_noise_magnitude=sensory_noise_magnitude,
problem_type=problem_type,
force_field_magnitude=force_field_magnitude,
cholesky_flag=cholesky_flag
cholesky_flag=cholesky_flag,
)

sol_socp = socp.solve(solver)
Expand Down Expand Up @@ -759,7 +763,9 @@ def main():
}

# --- Save the results --- #
with open(f"leuvenarm_torque_driven_socp_{problem_type}_forcefield{force_field_magnitude}_{cholesky_flag}.pkl", "wb") as file:
with open(

Check warning on line 766 in bioptim/examples/stochastic_optimal_control/arm_reaching_torque_driven_implicit.py

View check run for this annotation

Codecov / codecov/patch

bioptim/examples/stochastic_optimal_control/arm_reaching_torque_driven_implicit.py#L766

Added line #L766 was not covered by tests
f"leuvenarm_torque_driven_socp_{problem_type}_forcefield{force_field_magnitude}_{cholesky_flag}.pkl", "wb"
) as file:
pickle.dump(data, file)

# --- Visualize the results --- #
Expand Down
2 changes: 1 addition & 1 deletion bioptim/optimization/optimization_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def reshape_to_cholesky_matrix(variable, shape_0, node: Node, key: str):
matrix = MX.zeros(shape_0, shape_0)
i = 0
for s0 in range(shape_0):
for s1 in range(s0+1):
for s1 in range(s0 + 1):
if node == Node.START:
matrix[s0, s1] = variable[key].cx_start[i]
elif node == Node.MID:
Expand Down
192 changes: 149 additions & 43 deletions tests/test_global_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def test_arm_reaching_muscle_driven():
# TestUtils.simulate(sol) # TODO: charbie -> fix this
# for now, it does not match because the integration is done in the multinode_constraint


def test_arm_reaching_torque_driven_explicit():
from bioptim.examples.stochastic_optimal_control import arm_reaching_torque_driven_explicit as ocp_module

Expand Down Expand Up @@ -524,6 +525,7 @@ def test_arm_reaching_torque_driven_explicit():
),
)


@pytest.mark.parametrize("cholesky_flag", [True, False])
def test_arm_reaching_torque_driven_implicit(cholesky_flag):
from bioptim.examples.stochastic_optimal_control import arm_reaching_torque_driven_implicit as ocp_module
Expand Down Expand Up @@ -606,24 +608,37 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
np.testing.assert_almost_equal(qdot[:, -2], np.array([9.34909789e-11, 2.11489648e-10]))

np.testing.assert_almost_equal(tau[:, 0], np.array([0.74345001, -0.38482294]))
np.testing.assert_almost_equal(tau[:, -2], np.array([-0.69873141, 0.44427599]))
np.testing.assert_almost_equal(tau[:, -2], np.array([-0.69873141, 0.44427599]))

np.testing.assert_almost_equal(
k[:, 0],
np.array([0.01523928, 0.01556081, 0.03375243, 0.05246741, -0.00879659,
0.01632912, 0.00877083, -0.01418607]),
np.array(
[0.01523928, 0.01556081, 0.03375243, 0.05246741, -0.00879659, 0.01632912, 0.00877083, -0.01418607]
),
)
np.testing.assert_almost_equal(
ref[:, 0], np.array([2.81907783e-02, 2.84412560e-01, -3.84350362e-11, -6.31154841e-12])
ref[:, 0], np.array([2.81907783e-02, 2.84412560e-01, -3.84350362e-11, -6.31154841e-12])
)
np.testing.assert_almost_equal(
m[:, 0],
np.array(
[
1.11118843e+00, 4.33671754e-05, -1.24355084e-02, -1.92738667e-05,
7.71134318e-05, 1.11188416e+00, -1.41446562e-04, -1.23501869e-02,
-6.95870145e-03, -3.90308717e-03, 1.11919572e+00, 1.73466734e-03,
-6.94022101e-03, -6.95743720e-02, 1.27302256e-02, 1.11151678e+00,
1.11118843e00,
4.33671754e-05,
-1.24355084e-02,
-1.92738667e-05,
7.71134318e-05,
1.11188416e00,
-1.41446562e-04,
-1.23501869e-02,
-6.95870145e-03,
-3.90308717e-03,
1.11919572e00,
1.73466734e-03,
-6.94022101e-03,
-6.95743720e-02,
1.27302256e-02,
1.11151678e00,
]
),
)
Expand All @@ -632,10 +647,22 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
cov[:, -2],
np.array(
[
-8.80346012e-05, -4.69527095e-05, 8.35293213e-05, 1.56300610e-04,
-4.69527095e-05, -3.44615160e-05, 7.29566569e-05, 1.35527530e-04,
8.35293213e-05, 7.29566569e-05, -2.26287713e-04, -2.80104699e-04,
1.56300610e-04, 1.35527530e-04, -2.80104699e-04, -4.80293202e-04,
-8.80346012e-05,
-4.69527095e-05,
8.35293213e-05,
1.56300610e-04,
-4.69527095e-05,
-3.44615160e-05,
7.29566569e-05,
1.35527530e-04,
8.35293213e-05,
7.29566569e-05,
-2.26287713e-04,
-2.80104699e-04,
1.56300610e-04,
1.35527530e-04,
-2.80104699e-04,
-4.80293202e-04,
]
),
)
Expand All @@ -644,10 +671,22 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
a[:, 3],
np.array(
[
9.99999997e-01, -2.94678167e-09, -1.00000003e-01, -1.10867716e-10,
-2.79029383e-09, 9.99999997e-01, -2.20141327e-09, -1.00000000e-01,
3.02420237e-02, -4.20998617e-01, 9.87060279e-01, -3.08564981e-02,
-7.57260229e-03, 1.08904346e+00, 9.53610677e-02, 1.11920251e+00,
9.99999997e-01,
-2.94678167e-09,
-1.00000003e-01,
-1.10867716e-10,
-2.79029383e-09,
9.99999997e-01,
-2.20141327e-09,
-1.00000000e-01,
3.02420237e-02,
-4.20998617e-01,
9.87060279e-01,
-3.08564981e-02,
-7.57260229e-03,
1.08904346e00,
9.53610677e-02,
1.11920251e00,
]
),
)
Expand All @@ -656,12 +695,30 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
c[:, 3],
np.array(
[
-1.00026960e-12, 7.48271175e-12, 7.22298606e-12, -1.12880911e-11,
1.11223927e-02, -5.94119908e-03, -4.41155433e-13, 4.22693132e-12,
4.55705335e-12, -6.72512449e-12, 1.91518725e-02, -9.67304622e-03,
-1.34051958e+00, 1.51907793e+00, -4.45148469e-02, 1.50301525e-02,
-8.76853509e-02, 4.11969104e-02, 1.51907793e+00, -4.56171523e+00,
-1.48051682e-02, 6.70065631e-02, -9.60790421e-02, 4.58470601e-02,
-1.00026960e-12,
7.48271175e-12,
7.22298606e-12,
-1.12880911e-11,
1.11223927e-02,
-5.94119908e-03,
-4.41155433e-13,
4.22693132e-12,
4.55705335e-12,
-6.72512449e-12,
1.91518725e-02,
-9.67304622e-03,
-1.34051958e00,
1.51907793e00,
-4.45148469e-02,
1.50301525e-02,
-8.76853509e-02,
4.11969104e-02,
1.51907793e00,
-4.56171523e00,
-1.48051682e-02,
6.70065631e-02,
-9.60790421e-02,
4.58470601e-02,
]
),
)
Expand Down Expand Up @@ -690,25 +747,37 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
np.testing.assert_almost_equal(qdot[:, -2], np.array([1.14034274e-09, 1.77348396e-09]))

np.testing.assert_almost_equal(tau[:, 0], np.array([0.74341393, -0.38470965]))
np.testing.assert_almost_equal(tau[:, -2], np.array([-0.69875678, 0.44426507]))
np.testing.assert_almost_equal(tau[:, -2], np.array([-0.69875678, 0.44426507]))

np.testing.assert_almost_equal(
k[:, 0],
np.array(
[0.01531877, 0.01126498, 0.01593056, 0.01857115, -0.00125035,
-0.00515613, 0.00340021, -0.01075679]),
[0.01531877, 0.01126498, 0.01593056, 0.01857115, -0.00125035, -0.00515613, 0.00340021, -0.01075679]
),
)
np.testing.assert_almost_equal(
ref[:, 0], np.array([2.81907762e-02, 2.84412559e-01, -1.82246478e-10, -3.02336569e-12])
ref[:, 0], np.array([2.81907762e-02, 2.84412559e-01, -1.82246478e-10, -3.02336569e-12])
)
np.testing.assert_almost_equal(
m[:, 0],
np.array(
[
1.11111399e+00, 3.60727553e-05, -1.24942749e-02, -6.89880004e-05,
-1.55956208e-05, 1.11185104e+00, -2.11345774e-04, -1.23982943e-02,
-2.58769034e-04, -3.24653915e-03, 1.12448474e+00, 6.20892944e-03,
1.40361018e-03, -6.65935312e-02, 1.90211252e-02, 1.11584654e+00,
1.11111399e00,
3.60727553e-05,
-1.24942749e-02,
-6.89880004e-05,
-1.55956208e-05,
1.11185104e00,
-2.11345774e-04,
-1.23982943e-02,
-2.58769034e-04,
-3.24653915e-03,
1.12448474e00,
6.20892944e-03,
1.40361018e-03,
-6.65935312e-02,
1.90211252e-02,
1.11584654e00,
]
),
)
Expand All @@ -717,9 +786,16 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
cov[:, -2],
np.array(
[
-4.46821105e-03, -1.71731520e-03, -1.02009010e-02, -3.58196407e-03,
-6.50385303e-03, 9.57036181e-03, 2.93606642e-03, -1.82590044e-04,
8.51698871e-03, 9.33034990e-05,
-4.46821105e-03,
-1.71731520e-03,
-1.02009010e-02,
-3.58196407e-03,
-6.50385303e-03,
9.57036181e-03,
2.93606642e-03,
-1.82590044e-04,
8.51698871e-03,
9.33034990e-05,
]
),
)
Expand All @@ -728,10 +804,22 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
a[:, 3],
np.array(
[
1.00000000e+00, 1.08524580e-09, -9.99999991e-02, 2.72912724e-10,
-1.29617696e-10, 9.99999995e-01, -5.48136491e-09, -1.00000001e-01,
3.98959553e-02, -4.10112704e-01, 1.01332373e+00, -2.12383714e-02,
-7.85600590e-02, 1.06875322e+00, 2.12659225e-02, 1.09408356e+00,
1.00000000e00,
1.08524580e-09,
-9.99999991e-02,
2.72912724e-10,
-1.29617696e-10,
9.99999995e-01,
-5.48136491e-09,
-1.00000001e-01,
3.98959553e-02,
-4.10112704e-01,
1.01332373e00,
-2.12383714e-02,
-7.85600590e-02,
1.06875322e00,
2.12659225e-02,
1.09408356e00,
]
),
)
Expand All @@ -740,12 +828,30 @@ def test_arm_reaching_torque_driven_implicit(cholesky_flag):
c[:, 3],
np.array(
[
-9.28531424e-12, 2.10560432e-11, -8.74791141e-12, -1.84391377e-11,
-2.10526798e-04, 8.92312491e-04, 2.61170664e-11, -4.05508057e-11,
2.86516265e-11, 4.55817345e-11, -2.73807889e-02, 6.59798851e-02,
-1.34796595e+00, 1.37762629e+00, -2.57583298e-02, -3.30498283e-02,
-2.29074656e-02, 5.59309999e-02, 1.37762629e+00, -4.28727508e+00,
1.39574001e-02, 4.84712227e-02, 6.11340782e-04, -4.45068909e-03,
-9.28531424e-12,
2.10560432e-11,
-8.74791141e-12,
-1.84391377e-11,
-2.10526798e-04,
8.92312491e-04,
2.61170664e-11,
-4.05508057e-11,
2.86516265e-11,
4.55817345e-11,
-2.73807889e-02,
6.59798851e-02,
-1.34796595e00,
1.37762629e00,
-2.57583298e-02,
-3.30498283e-02,
-2.29074656e-02,
5.59309999e-02,
1.37762629e00,
-4.28727508e00,
1.39574001e-02,
4.84712227e-02,
6.11340782e-04,
-4.45068909e-03,
]
),
)

0 comments on commit 691a73a

Please sign in to comment.