diff --git a/.gitignore b/.gitignore index 20abab0..e8ac431 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,3 @@ imgui.ini .DS_Store /.idea /log - -# Examples -/example/vector_retargeting/data/ -!example/vector_retargeting/data/human_hand_video.mp4 diff --git a/dex_retargeting/retargeting_config.py b/dex_retargeting/retargeting_config.py index 55607a9..fd6f7fd 100644 --- a/dex_retargeting/retargeting_config.py +++ b/dex_retargeting/retargeting_config.py @@ -6,11 +6,12 @@ import numpy as np import yaml +from dex_retargeting import yourdfpy as urdf +from dex_retargeting.kinematics_adaptor import MimicJointKinematicAdaptor from dex_retargeting.optimizer_utils import LPFilter from dex_retargeting.robot_wrapper import RobotWrapper from dex_retargeting.seq_retarget import SeqRetargeting -from dex_retargeting import yourdfpy as urdf -from dex_retargeting.kinematics_adaptor import MimicJointKinematicAdaptor +from dex_retargeting.yourdfpy import DUMMY_JOINT_NAMES @dataclass @@ -136,7 +137,9 @@ def build(self) -> SeqRetargeting: import tempfile # Process the URDF with yourdfpy to better find file path - robot_urdf = urdf.URDF.load(self.urdf_path, build_scene_graph=False) + robot_urdf = urdf.URDF.load( + self.urdf_path, add_dummy_free_joints=self.add_dummy_free_joint, build_scene_graph=False + ) urdf_name = self.urdf_path.split("/")[-1] temp_dir = tempfile.mkdtemp(prefix="dex_retargeting-") temp_path = f"{temp_dir}/{urdf_name}" @@ -144,7 +147,12 @@ def build(self) -> SeqRetargeting: # Load pinocchio model robot = RobotWrapper(temp_path) + + # Add 6D dummy joint to target joint names so that it will also be optimized + if self.add_dummy_free_joint and self.target_joint_names is not None: + self.target_joint_names = DUMMY_JOINT_NAMES + self.target_joint_names joint_names = self.target_joint_names if self.target_joint_names is not None else robot.dof_joint_names + if self.type == "position": optimizer = PositionOptimizer( robot, @@ -210,7 +218,7 @@ def build(self) -> SeqRetargeting: return retargeting -def get_retargeting_config(config_path) -> RetargetingConfig: +def get_retargeting_config(config_path: Union[str, Path]) -> RetargetingConfig: config = RetargetingConfig.load_from_file(config_path) return config @@ -228,12 +236,3 @@ def parse_mimic_joint(robot_urdf: urdf.URDF) -> Tuple[bool, List[str], List[str] offsets.append(joint.mimic.offset) return len(mimic_joint_names) > 0, source_joint_names, mimic_joint_names, multipliers, offsets - - -if __name__ == "__main__": - # Path below is relative to this file - - test_config = get_retargeting_config(str(Path(__file__).parent / "configs/allegro_hand.yml")) - print(test_config) - opt = test_config.build() - print(opt.optimizer.target_link_human_indices) diff --git a/dex_retargeting/seq_retarget.py b/dex_retargeting/seq_retarget.py index c806b41..5c28133 100644 --- a/dex_retargeting/seq_retarget.py +++ b/dex_retargeting/seq_retarget.py @@ -52,8 +52,6 @@ def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, globa wrist_orientation: orientation of the hand orientation, typically from human hand pose in MANO convention global_rot: - Returns: - """ # This function can only be used when the first joints of robot are free joints if len(wrist_pos) != 3: diff --git a/dex_retargeting/yourdfpy.py b/dex_retargeting/yourdfpy.py index 405fc5a..ce8e287 100644 --- a/dex_retargeting/yourdfpy.py +++ b/dex_retargeting/yourdfpy.py @@ -893,7 +893,7 @@ def _validate_required_attribute(self, attribute, error_msg, allowed_values=None self._errors.append(URDFAttributeValueError(error_msg)) @staticmethod - def load(fname_or_file, **kwargs): + def load(fname_or_file, add_dummy_free_joints=False, **kwargs): """Load URDF file from filename or file object. Args: @@ -947,7 +947,9 @@ def load(fname_or_file, **kwargs): etree.strip_tags(xml_root, etree.Comment) etree.cleanup_namespaces(xml_root) - return URDF(robot=URDF._parse_robot(xml_element=xml_root), **kwargs) + return URDF( + robot=URDF._parse_robot(xml_element=xml_root, add_dummy_free_joints=add_dummy_free_joints), **kwargs + ) def contains(self, key, value, element=None) -> bool: """Checks recursively whether the URDF tree contains the provided key-value pair. @@ -2060,7 +2062,7 @@ def _write_joint(self, xml_parent, joint): self._write_dynamics(xml_element, joint.dynamics) @staticmethod - def _parse_robot(xml_element): + def _parse_robot(xml_element, add_dummy_free_joints=False): robot = Robot(name=xml_element.attrib["name"]) for l in xml_element.findall("link"): @@ -2069,6 +2071,19 @@ def _parse_robot(xml_element): robot.joints.append(URDF._parse_joint(j)) for m in xml_element.findall("material"): robot.materials.append(URDF._parse_material(m)) + + if add_dummy_free_joints: + # Determine root link + link_names = [l.name for l in robot.links] + for j in robot.joints: + link_names.remove(j.child) + + if len(link_names) == 0: + raise RuntimeError(f"No root link found for robot.") + + root_link_name = link_names[0] + _add_dummy_joints(robot, root_link_name) + return robot def _validate_robot(self, robot): @@ -2176,3 +2191,47 @@ def get_link_global_transform(self, link_name): node = anytree.search.findall_by_attr(self.tree_root, link_name)[0] return node.global_pose + + +def _add_dummy_joints(robot: Robot, root_link_name: str): + # Prepare link and joint properties + translation_range = (-5, 5) + rotation_range = (-2 * np.pi, 2 * np.pi) + joint_types = ["prismatic"] * 3 + ["revolute"] * 3 + joint_limit = [translation_range] * 3 + [rotation_range] * 3 + joint_name = DUMMY_JOINT_NAMES.copy() + link_name = [f"dummy_{name}_translation_link" for name in "xyz"] + [f"dummy_{name}_rotation_link" for name in "xyz"] + + links = [] + joints = [] + + for i in range(6): + inertial = Inertial( + mass=0.01, inertia=np.array([[1e-4, 0, 0], [0, 1e-4, 0], [0, 0, 1e-4]]), origin=np.identity(4) + ) + link = Link(name=link_name[i], inertial=inertial) + links.append(link) + + joint_axis = np.zeros(3, dtype=int) + joint_axis[i % 3] = 1 + limit = Limit(lower=joint_limit[i][0], upper=joint_limit[i][1], velocity=3.14, effort=10) + + child_name = link_name[i + 1] if i < 5 else root_link_name + joint = Joint( + name=joint_name[i], + type=joint_types[i], + parent=link_name[i], + child=child_name, + origin=np.identity(4), + axis=joint_axis, + limit=limit, + ) + joints.append(joint) + + robot.joints = joints + robot.joints + robot.links = links + robot.links + + +DUMMY_JOINT_NAMES = [f"dummy_{name}_translation_joint" for name in "xyz"] + [ + f"dummy_{name}_rotation_joint" for name in "xyz" +] diff --git a/example/vector_retargeting/.gitignore b/example/vector_retargeting/.gitignore new file mode 100644 index 0000000..7b753d3 --- /dev/null +++ b/example/vector_retargeting/.gitignore @@ -0,0 +1,3 @@ +# Examples +/example/vector_retargeting/data/ +!example/vector_retargeting/data/human_hand_video.mp4 diff --git a/tests/test_retargeting_config.py b/tests/test_retargeting_config.py index 5355a7f..dac280f 100644 --- a/tests/test_retargeting_config.py +++ b/tests/test_retargeting_config.py @@ -60,8 +60,6 @@ def test_dict_config_parsing(self): target_link_human_indices: [ 4, 8, 12, 16, 20 ] - # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency - # 1 means no filter while 0 means not moving low_pass_alpha: 1 """ cfg_dict = yaml.safe_load(cfg_str) @@ -75,29 +73,24 @@ def test_multi_dict_config_parsing(self): urdf_path: allegro_hand/allegro_hand_right.urdf wrist_link_name: "wrist" - # Target refers to the retargeting target, which is the robot hand target_joint_names: null target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ] target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] scaling_factor: 1.6 - # Source refers to the retargeting input, which usually corresponds to the human hand # The joint indices of human hand joint which corresponds to each link in the target_link_names target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ] - # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency low_pass_alpha: 0.2 - + - type: DexPilot urdf_path: leap_hand/leap_hand_right.urdf wrist_link_name: "base" - # Target refers to the retargeting target, which is the robot hand target_joint_names: null finger_tip_link_names: [ "thumb_tip_head", "index_tip_head", "middle_tip_head", "ring_tip_head" ] scaling_factor: 1.6 - # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency low_pass_alpha: 0.2 """ cfg_dict_list = yaml.safe_load(cfg_str) @@ -107,3 +100,24 @@ def test_multi_dict_config_parsing(self): retargeting = config.build() retargetings.append(retargeting) assert isinstance(retargeting, SeqRetargeting) + + @pytest.mark.parametrize("config_path", POSITION_CONFIG_DICT.values()) + def test_add_dummy_joint(self, config_path): + config_path = self.config_dir / config_path + override = {"add_dummy_free_joint": False} + config = RetargetingConfig.load_from_file(config_path, override) + retargeting = config.build() + robot = retargeting.optimizer.robot + original_robot_dof = robot.dof + original_active_dof = len(retargeting.optimizer.target_joint_names) + + override = {"add_dummy_free_joint": True} + config = RetargetingConfig.load_from_file(config_path, override) + retargeting = config.build() + robot = retargeting.optimizer.robot + + assert robot.dof == original_robot_dof + 6 + assert retargeting.joint_limits.shape == (original_active_dof + 6, 2) + dummy_joint_names = robot.dof_joint_names[:6] + for i in range(6): + assert "dummy" in dummy_joint_names[i]