Skip to content

Commit

Permalink
[update] update DexPilot teleop for hands with 2-5 fingers
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingry committed Jul 30, 2024
1 parent b615da9 commit 619030d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 32 deletions.
85 changes: 55 additions & 30 deletions dex_retargeting/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,14 @@ def __init__(
eta2=3e-2,
scaling=1.0,
):
# if len(finger_tip_link_names) < 4 or len(finger_tip_link_names) > 5:
# raise ValueError(f"DexPilot optimizer can only be applied to hands with four or five fingers")
if len(finger_tip_link_names) < 2 or len(finger_tip_link_names) > 5:
raise ValueError(
f"DexPilot optimizer can only be applied to hands with 2 to 5 fingers, but got "
f"{len(finger_tip_link_names)} fingers."
)
self.num_fingers = len(finger_tip_link_names)

if self.num_fingers == 2: # For gripper
origin_link_index = [2, 0, 0]
task_link_index = [1, 1, 2]
elif self.num_fingers == 4:
origin_link_index = [2, 3, 4, 3, 4, 4, 0, 0, 0, 0]
task_link_index = [1, 1, 1, 2, 2, 3, 1, 2, 3, 4]
elif self.num_fingers == 5:
origin_link_index = [2, 3, 4, 5, 3, 4, 5, 4, 5, 5, 0, 0, 0, 0, 0]
task_link_index = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 1, 2, 3, 4, 5]
else:
raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}")
origin_link_index, task_link_index = self.generate_link_indices(self.num_fingers)

if target_link_human_indices is None:
target_link_human_indices = (np.stack([origin_link_index, task_link_index], axis=0) * 4).astype(int)
Expand Down Expand Up @@ -363,23 +356,55 @@ def __init__(
self.opt.set_ftol_abs(1e-6)

# DexPilot cache
if self.num_fingers == 2:
self.projected = np.zeros(1, dtype=bool)
self.s2_project_index_origin = np.array([], dtype=int)
self.s2_project_index_task = np.array([], dtype=int)
self.projected_dist = np.array([eta1] * 1)
elif self.num_fingers == 4:
self.projected = np.zeros(6, dtype=bool)
self.s2_project_index_origin = np.array([1, 2, 2], dtype=int)
self.s2_project_index_task = np.array([0, 0, 1], dtype=int)
self.projected_dist = np.array([eta1] * 3 + [eta2] * 3)
elif self.num_fingers == 5:
self.projected = np.zeros(10, dtype=bool)
self.s2_project_index_origin = np.array([1, 2, 3, 2, 3, 3], dtype=int)
self.s2_project_index_task = np.array([0, 0, 0, 1, 1, 2], dtype=int)
self.projected_dist = np.array([eta1] * 4 + [eta2] * 6)
else:
raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}")
self.projected, self.s2_project_index_origin, self.s2_project_index_task, self.projected_dist = (
self.set_dexpilot_cache(self.num_fingers, eta1, eta2)
)

@staticmethod
def generate_link_indices(num_fingers):
"""
Example:
>>> generate_link_indices(4)
([2, 3, 4, 3, 4, 4, 0, 0, 0, 0], [1, 1, 1, 2, 2, 3, 1, 2, 3, 4])
"""
origin_link_index = []
task_link_index = []

# Add indices for connections between fingers
for i in range(1, num_fingers):
for j in range(i + 1, num_fingers + 1):
origin_link_index.append(j)
task_link_index.append(i)

# Add indices for connections to the base (0)
for i in range(1, num_fingers + 1):
origin_link_index.append(0)
task_link_index.append(i)

return origin_link_index, task_link_index

@staticmethod
def set_dexpilot_cache(num_fingers, eta1, eta2):
"""
Example:
>>> set_dexpilot_cache(4, 0.1, 0.2)
(array([False, False, False, False, False, False]),
[1, 2, 2],
[0, 0, 1],
array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]))
"""
projected = np.zeros(num_fingers * (num_fingers - 1) // 2, dtype=bool)

s2_project_index_origin = []
s2_project_index_task = []
for i in range(0, num_fingers - 2):
for j in range(i + 1, num_fingers - 1):
s2_project_index_origin.append(j)
s2_project_index_task.append(i)

projected_dist = np.array([eta1] * (num_fingers - 1) + [eta2] * ((num_fingers - 1) * (num_fingers - 2) // 2))

return projected, s2_project_index_origin, s2_project_index_task, projected_dist

def get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray):
qpos = np.zeros(self.num_joints)
Expand Down
2 changes: 1 addition & 1 deletion example/vector_retargeting/capture_webcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main(video_path: str, video_capture_device: Union[str, int] = 0):
if cv2.waitKey(1) & 0xFF == 27:
break

print('Recording finished')
print("Recording finished")
cap.release()
writer.release()
cv2.destroyAllWindows()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def setup_package():
python_requires=">=3.7,<3.11",
zip_safe=True,
include_package_data=True,
package_data={'dex_retargeting': ['configs/**']},
package_data={"dex_retargeting": ["configs/**"]},
install_requires=core_requirements,
extras_require={
"dev": dev_requirements,
Expand Down

0 comments on commit 619030d

Please sign in to comment.