Skip to content

Commit

Permalink
[add] add dummy joint support again with pinocchio backend
Browse files Browse the repository at this point in the history
  • Loading branch information
yzqin committed May 17, 2024
1 parent 7c4c3b2 commit c7692d0
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 30 deletions.
4 changes: 0 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ imgui.ini
.DS_Store
/.idea
/log

# Examples
/example/vector_retargeting/data/
!example/vector_retargeting/data/human_hand_video.mp4
25 changes: 12 additions & 13 deletions dex_retargeting/retargeting_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import numpy as np
import yaml

from dex_retargeting import yourdfpy as urdf
from dex_retargeting.kinematics_adaptor import MimicJointKinematicAdaptor
from dex_retargeting.optimizer_utils import LPFilter
from dex_retargeting.robot_wrapper import RobotWrapper
from dex_retargeting.seq_retarget import SeqRetargeting
from dex_retargeting import yourdfpy as urdf
from dex_retargeting.kinematics_adaptor import MimicJointKinematicAdaptor
from dex_retargeting.yourdfpy import DUMMY_JOINT_NAMES


@dataclass
Expand Down Expand Up @@ -136,15 +137,22 @@ def build(self) -> SeqRetargeting:
import tempfile

# Process the URDF with yourdfpy to better find file path
robot_urdf = urdf.URDF.load(self.urdf_path, build_scene_graph=False)
robot_urdf = urdf.URDF.load(
self.urdf_path, add_dummy_free_joints=self.add_dummy_free_joint, build_scene_graph=False
)
urdf_name = self.urdf_path.split("/")[-1]
temp_dir = tempfile.mkdtemp(prefix="dex_retargeting-")
temp_path = f"{temp_dir}/{urdf_name}"
robot_urdf.write_xml_file(temp_path)

# Load pinocchio model
robot = RobotWrapper(temp_path)

# Add 6D dummy joint to target joint names so that it will also be optimized
if self.add_dummy_free_joint and self.target_joint_names is not None:
self.target_joint_names = DUMMY_JOINT_NAMES + self.target_joint_names
joint_names = self.target_joint_names if self.target_joint_names is not None else robot.dof_joint_names

if self.type == "position":
optimizer = PositionOptimizer(
robot,
Expand Down Expand Up @@ -210,7 +218,7 @@ def build(self) -> SeqRetargeting:
return retargeting


def get_retargeting_config(config_path) -> RetargetingConfig:
def get_retargeting_config(config_path: Union[str, Path]) -> RetargetingConfig:
config = RetargetingConfig.load_from_file(config_path)
return config

Expand All @@ -228,12 +236,3 @@ def parse_mimic_joint(robot_urdf: urdf.URDF) -> Tuple[bool, List[str], List[str]
offsets.append(joint.mimic.offset)

return len(mimic_joint_names) > 0, source_joint_names, mimic_joint_names, multipliers, offsets


if __name__ == "__main__":
# Path below is relative to this file

test_config = get_retargeting_config(str(Path(__file__).parent / "configs/allegro_hand.yml"))
print(test_config)
opt = test_config.build()
print(opt.optimizer.target_link_human_indices)
2 changes: 0 additions & 2 deletions dex_retargeting/seq_retarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, globa
wrist_orientation: orientation of the hand orientation, typically from human hand pose in MANO convention
global_rot:
Returns:
"""
# This function can only be used when the first joints of robot are free joints
if len(wrist_pos) != 3:
Expand Down
65 changes: 62 additions & 3 deletions dex_retargeting/yourdfpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def _validate_required_attribute(self, attribute, error_msg, allowed_values=None
self._errors.append(URDFAttributeValueError(error_msg))

@staticmethod
def load(fname_or_file, **kwargs):
def load(fname_or_file, add_dummy_free_joints=False, **kwargs):
"""Load URDF file from filename or file object.
Args:
Expand Down Expand Up @@ -947,7 +947,9 @@ def load(fname_or_file, **kwargs):
etree.strip_tags(xml_root, etree.Comment)
etree.cleanup_namespaces(xml_root)

return URDF(robot=URDF._parse_robot(xml_element=xml_root), **kwargs)
return URDF(
robot=URDF._parse_robot(xml_element=xml_root, add_dummy_free_joints=add_dummy_free_joints), **kwargs
)

def contains(self, key, value, element=None) -> bool:
"""Checks recursively whether the URDF tree contains the provided key-value pair.
Expand Down Expand Up @@ -2060,7 +2062,7 @@ def _write_joint(self, xml_parent, joint):
self._write_dynamics(xml_element, joint.dynamics)

@staticmethod
def _parse_robot(xml_element):
def _parse_robot(xml_element, add_dummy_free_joints=False):
robot = Robot(name=xml_element.attrib["name"])

for l in xml_element.findall("link"):
Expand All @@ -2069,6 +2071,19 @@ def _parse_robot(xml_element):
robot.joints.append(URDF._parse_joint(j))
for m in xml_element.findall("material"):
robot.materials.append(URDF._parse_material(m))

if add_dummy_free_joints:
# Determine root link
link_names = [l.name for l in robot.links]
for j in robot.joints:
link_names.remove(j.child)

if len(link_names) == 0:
raise RuntimeError(f"No root link found for robot.")

root_link_name = link_names[0]
_add_dummy_joints(robot, root_link_name)

return robot

def _validate_robot(self, robot):
Expand Down Expand Up @@ -2176,3 +2191,47 @@ def get_link_global_transform(self, link_name):
node = anytree.search.findall_by_attr(self.tree_root, link_name)[0]

return node.global_pose


def _add_dummy_joints(robot: Robot, root_link_name: str):
# Prepare link and joint properties
translation_range = (-5, 5)
rotation_range = (-2 * np.pi, 2 * np.pi)
joint_types = ["prismatic"] * 3 + ["revolute"] * 3
joint_limit = [translation_range] * 3 + [rotation_range] * 3
joint_name = DUMMY_JOINT_NAMES.copy()
link_name = [f"dummy_{name}_translation_link" for name in "xyz"] + [f"dummy_{name}_rotation_link" for name in "xyz"]

links = []
joints = []

for i in range(6):
inertial = Inertial(
mass=0.01, inertia=np.array([[1e-4, 0, 0], [0, 1e-4, 0], [0, 0, 1e-4]]), origin=np.identity(4)
)
link = Link(name=link_name[i], inertial=inertial)
links.append(link)

joint_axis = np.zeros(3, dtype=int)
joint_axis[i % 3] = 1
limit = Limit(lower=joint_limit[i][0], upper=joint_limit[i][1], velocity=3.14, effort=10)

child_name = link_name[i + 1] if i < 5 else root_link_name
joint = Joint(
name=joint_name[i],
type=joint_types[i],
parent=link_name[i],
child=child_name,
origin=np.identity(4),
axis=joint_axis,
limit=limit,
)
joints.append(joint)

robot.joints = joints + robot.joints
robot.links = links + robot.links


DUMMY_JOINT_NAMES = [f"dummy_{name}_translation_joint" for name in "xyz"] + [
f"dummy_{name}_rotation_joint" for name in "xyz"
]
3 changes: 3 additions & 0 deletions example/vector_retargeting/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Examples
/example/vector_retargeting/data/
!example/vector_retargeting/data/human_hand_video.mp4
30 changes: 22 additions & 8 deletions tests/test_retargeting_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def test_dict_config_parsing(self):
target_link_human_indices: [ 4, 8, 12, 16, 20 ]
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
# 1 means no filter while 0 means not moving
low_pass_alpha: 1
"""
cfg_dict = yaml.safe_load(cfg_str)
Expand All @@ -75,29 +73,24 @@ def test_multi_dict_config_parsing(self):
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"
# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ]
target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
scaling_factor: 1.6
# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names
target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ]
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
- type: DexPilot
urdf_path: leap_hand/leap_hand_right.urdf
wrist_link_name: "base"
# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
finger_tip_link_names: [ "thumb_tip_head", "index_tip_head", "middle_tip_head", "ring_tip_head" ]
scaling_factor: 1.6
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
"""
cfg_dict_list = yaml.safe_load(cfg_str)
Expand All @@ -107,3 +100,24 @@ def test_multi_dict_config_parsing(self):
retargeting = config.build()
retargetings.append(retargeting)
assert isinstance(retargeting, SeqRetargeting)

@pytest.mark.parametrize("config_path", POSITION_CONFIG_DICT.values())
def test_add_dummy_joint(self, config_path):
config_path = self.config_dir / config_path
override = {"add_dummy_free_joint": False}
config = RetargetingConfig.load_from_file(config_path, override)
retargeting = config.build()
robot = retargeting.optimizer.robot
original_robot_dof = robot.dof
original_active_dof = len(retargeting.optimizer.target_joint_names)

override = {"add_dummy_free_joint": True}
config = RetargetingConfig.load_from_file(config_path, override)
retargeting = config.build()
robot = retargeting.optimizer.robot

assert robot.dof == original_robot_dof + 6
assert retargeting.joint_limits.shape == (original_active_dof + 6, 2)
dummy_joint_names = robot.dof_joint_names[:6]
for i in range(6):
assert "dummy" in dummy_joint_names[i]

0 comments on commit c7692d0

Please sign in to comment.