From 947b66ead90d200ec5f7f98420c7e8596aca3f9a Mon Sep 17 00:00:00 2001 From: yzqin Date: Thu, 21 Mar 2024 14:56:24 -0700 Subject: [PATCH] [update] update dict config parse --- dex_retargeting/retargeting_config.py | 21 ++++++--- tests/test_retargeting_config.py | 66 ++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/dex_retargeting/retargeting_config.py b/dex_retargeting/retargeting_config.py index 076df22..2010dd0 100644 --- a/dex_retargeting/retargeting_config.py +++ b/dex_retargeting/retargeting_config.py @@ -1,7 +1,7 @@ import sapien.core as sapien from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Any from typing import Union import numpy as np @@ -45,6 +45,9 @@ class RetargetingConfig: normal_delta: float = 4e-3 huber_delta: float = 2e-2 + # Constraint parameters + constraint_map: Optional[Dict[str, np.ndarray]] = None + # Joint limit tag has_joint_limits: bool = True @@ -110,12 +113,16 @@ def load_from_file(cls, config_path: Union[str, Path], override: Optional[Dict] with path.open("r") as f: yaml_config = yaml.load(f, Loader=yaml.FullLoader) cfg = yaml_config["retargeting"] - if "target_link_human_indices" in cfg: - cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"]) - if override is not None: - for key, value in override.items(): - cfg[key] = value - config = RetargetingConfig(**cfg) + return cls.from_dict(cfg, override) + + @classmethod + def from_dict(cls, cfg: Dict[str, Any], override: Optional[Dict] = None): + if "target_link_human_indices" in cfg: + cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"]) + if override is not None: + for key, value in override.items(): + cfg[key] = value + config = RetargetingConfig(**cfg) return config def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting: diff --git a/tests/test_retargeting_config.py b/tests/test_retargeting_config.py index 19bfaeb..2f6c81d 100644 --- a/tests/test_retargeting_config.py +++ b/tests/test_retargeting_config.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +import yaml from dex_retargeting.retargeting_config import RetargetingConfig from dex_retargeting.seq_retarget import SeqRetargeting @@ -36,8 +37,71 @@ class TestRetargetingConfig: ) @pytest.mark.parametrize("config_path", config_paths) - def test_config_parsing(self, config_path): + def test_path_config_parsing(self, config_path): config_path = self.config_dir / config_path config = RetargetingConfig.load_from_file(config_path) retargeting = config.build() + assert isinstance(retargeting, SeqRetargeting) + + def test_dict_config_parsing(self): + cfg_str = """ + type: vector + urdf_path: allegro_hand/allegro_hand_right.urdf + wrist_link_name: "wrist" + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ] + target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + scaling_factor: 1.6 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 + """ + cfg_dict = yaml.safe_load(cfg_str) + config = RetargetingConfig.from_dict(cfg_dict) + retargeting = config.build() assert type(retargeting) == SeqRetargeting + + def test_multi_dict_config_parsing(self): + cfg_str = """ + - type: vector + urdf_path: allegro_hand/allegro_hand_right.urdf + wrist_link_name: "wrist" + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ] + target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + scaling_factor: 1.6 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 + + - type: DexPilot + urdf_path: leap_hand/leap_hand_right.urdf + wrist_link_name: "base" + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + finger_tip_link_names: [ "thumb_tip_head", "index_tip_head", "middle_tip_head", "ring_tip_head" ] + scaling_factor: 1.6 + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 + """ + cfg_dict_list = yaml.safe_load(cfg_str) + retargetings = [] + for cfg_dict in cfg_dict_list: + config = RetargetingConfig.from_dict(cfg_dict) + retargeting = config.build() + retargetings.append(retargeting) + assert isinstance(retargeting, SeqRetargeting)