diff --git a/dex_retargeting/optimizer.py b/dex_retargeting/optimizer.py index fdc606e..f8d3eb7 100644 --- a/dex_retargeting/optimizer.py +++ b/dex_retargeting/optimizer.py @@ -314,21 +314,14 @@ def __init__( eta2=3e-2, scaling=1.0, ): - # if len(finger_tip_link_names) < 4 or len(finger_tip_link_names) > 5: - # raise ValueError(f"DexPilot optimizer can only be applied to hands with four or five fingers") + if len(finger_tip_link_names) < 2 or len(finger_tip_link_names) > 5: + raise ValueError( + f"DexPilot optimizer can only be applied to hands with 2 to 5 fingers, but got " + f"{len(finger_tip_link_names)} fingers." + ) self.num_fingers = len(finger_tip_link_names) - if self.num_fingers == 2: # For gripper - origin_link_index = [2, 0, 0] - task_link_index = [1, 1, 2] - elif self.num_fingers == 4: - origin_link_index = [2, 3, 4, 3, 4, 4, 0, 0, 0, 0] - task_link_index = [1, 1, 1, 2, 2, 3, 1, 2, 3, 4] - elif self.num_fingers == 5: - origin_link_index = [2, 3, 4, 5, 3, 4, 5, 4, 5, 5, 0, 0, 0, 0, 0] - task_link_index = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 1, 2, 3, 4, 5] - else: - raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}") + origin_link_index, task_link_index = self.generate_link_indices(self.num_fingers) if target_link_human_indices is None: target_link_human_indices = (np.stack([origin_link_index, task_link_index], axis=0) * 4).astype(int) @@ -363,23 +356,55 @@ def __init__( self.opt.set_ftol_abs(1e-6) # DexPilot cache - if self.num_fingers == 2: - self.projected = np.zeros(1, dtype=bool) - self.s2_project_index_origin = np.array([], dtype=int) - self.s2_project_index_task = np.array([], dtype=int) - self.projected_dist = np.array([eta1] * 1) - elif self.num_fingers == 4: - self.projected = np.zeros(6, dtype=bool) - self.s2_project_index_origin = np.array([1, 2, 2], dtype=int) - self.s2_project_index_task = np.array([0, 0, 1], dtype=int) - self.projected_dist = np.array([eta1] * 3 + [eta2] * 3) - elif self.num_fingers == 5: - self.projected = np.zeros(10, dtype=bool) - self.s2_project_index_origin = np.array([1, 2, 3, 2, 3, 3], dtype=int) - self.s2_project_index_task = np.array([0, 0, 0, 1, 1, 2], dtype=int) - self.projected_dist = np.array([eta1] * 4 + [eta2] * 6) - else: - raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}") + self.projected, self.s2_project_index_origin, self.s2_project_index_task, self.projected_dist = ( + self.set_dexpilot_cache(self.num_fingers, eta1, eta2) + ) + + @staticmethod + def generate_link_indices(num_fingers): + """ + Example: + >>> generate_link_indices(4) + ([2, 3, 4, 3, 4, 4, 0, 0, 0, 0], [1, 1, 1, 2, 2, 3, 1, 2, 3, 4]) + """ + origin_link_index = [] + task_link_index = [] + + # Add indices for connections between fingers + for i in range(1, num_fingers): + for j in range(i + 1, num_fingers + 1): + origin_link_index.append(j) + task_link_index.append(i) + + # Add indices for connections to the base (0) + for i in range(1, num_fingers + 1): + origin_link_index.append(0) + task_link_index.append(i) + + return origin_link_index, task_link_index + + @staticmethod + def set_dexpilot_cache(num_fingers, eta1, eta2): + """ + Example: + >>> set_dexpilot_cache(4, 0.1, 0.2) + (array([False, False, False, False, False, False]), + [1, 2, 2], + [0, 0, 1], + array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])) + """ + projected = np.zeros(num_fingers * (num_fingers - 1) // 2, dtype=bool) + + s2_project_index_origin = [] + s2_project_index_task = [] + for i in range(0, num_fingers - 2): + for j in range(i + 1, num_fingers - 1): + s2_project_index_origin.append(j) + s2_project_index_task.append(i) + + projected_dist = np.array([eta1] * (num_fingers - 1) + [eta2] * ((num_fingers - 1) * (num_fingers - 2) // 2)) + + return projected, s2_project_index_origin, s2_project_index_task, projected_dist def get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray): qpos = np.zeros(self.num_joints) diff --git a/example/vector_retargeting/capture_webcam.py b/example/vector_retargeting/capture_webcam.py index e3de15d..ac09c4b 100644 --- a/example/vector_retargeting/capture_webcam.py +++ b/example/vector_retargeting/capture_webcam.py @@ -30,7 +30,7 @@ def main(video_path: str, video_capture_device: Union[str, int] = 0): if cv2.waitKey(1) & 0xFF == 27: break - print('Recording finished') + print("Recording finished") cap.release() writer.release() cv2.destroyAllWindows() diff --git a/setup.py b/setup.py index 910a505..8e5a66d 100644 --- a/setup.py +++ b/setup.py @@ -107,7 +107,7 @@ def setup_package(): python_requires=">=3.7,<3.11", zip_safe=True, include_package_data=True, - package_data={'dex_retargeting': ['configs/**']}, + package_data={"dex_retargeting": ["configs/**"]}, install_requires=core_requirements, extras_require={ "dev": dev_requirements,