Skip to content

Commit

Permalink
[update] update to v0.0.2, support warm start, which initialize the h…
Browse files Browse the repository at this point in the history
…and wrist pose based on human root pose. This is helpful for non-sequential position retargeting
  • Loading branch information
yzqin committed Sep 27, 2023
1 parent 32da0d0 commit dcef611
Show file tree
Hide file tree
Showing 16 changed files with 112 additions and 20 deletions.
2 changes: 1 addition & 1 deletion dex_retargeting/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.1"
__version__ = "0.0.2"
6 changes: 4 additions & 2 deletions dex_retargeting/configs/offline/allegro_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
retargeting:
type: position
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"

target_joint_names: null
target_link_names: [ "palm", "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
target_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip", "link_14.0",
"link_2.0", "link_6.0", "link_10.0" ]

target_link_human_indices: [ 0, 4, 8, 12, 16 ]
target_link_human_indices: [ 4, 8, 12, 16, 2, 6, 10, 14 ]

# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
# 1 means no filter while 0 means not moving
Expand Down
5 changes: 3 additions & 2 deletions dex_retargeting/configs/offline/schunk_svh_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
retargeting:
type: position
urdf_path: schunk_hand/schunk_svh_hand_right.urdf
wrist_link_name: "right_hand_base_link"

target_joint_names: null
target_link_names: [ "right_hand_e1", "right_hand_c", "right_hand_t", "right_hand_s", "right_hand_r",
target_link_names: [ "right_hand_c", "right_hand_t", "right_hand_s", "right_hand_r",
"right_hand_q", "right_hand_b", "right_hand_p", "right_hand_o", "right_hand_n", "right_hand_i"]

target_link_human_indices: [ 0, 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ]
target_link_human_indices: [ 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ]

# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
# 1 means no filter while 0 means not moving
Expand Down
5 changes: 3 additions & 2 deletions dex_retargeting/configs/offline/shadow_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
retargeting:
type: position
urdf_path: shadow_hand/shadow_hand_right.urdf
wrist_link_name: "ee_link"

target_joint_names: null
target_link_names: [ "palm", "thtip", "fftip", "mftip", "rftip", "lftip",
target_link_names: [ "thtip", "fftip", "mftip", "rftip", "lftip",
"thmiddle", "ffmiddle", "mfmiddle", "rfmiddle", "lfmiddle" ]

target_link_human_indices: [ 0, 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ]
target_link_human_indices: [ 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ]

# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
# 1 means no filter while 0 means not moving
Expand Down
1 change: 1 addition & 0 deletions dex_retargeting/configs/teleop/allegro_hand_left.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
retargeting:
type: vector
urdf_path: allegro_hand/allegro_hand_left.urdf
wrist_link_name: "wrist"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
retargeting:
type: DexPilot
urdf_path: allegro_hand/allegro_hand_left.urdf
wrist_link_name: "wrist"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
wrist_link_name: "wrist"
finger_tip_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
scaling_factor: 1.6

Expand Down
1 change: 1 addition & 0 deletions dex_retargeting/configs/teleop/allegro_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
retargeting:
type: vector
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
retargeting:
type: DexPilot
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
wrist_link_name: "wrist"
finger_tip_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
scaling_factor: 1.6

Expand Down
1 change: 1 addition & 0 deletions dex_retargeting/configs/teleop/schunk_svh_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
retargeting:
type: vector
urdf_path: schunk_hand/schunk_svh_hand_right.urdf
wrist_link_name: "right_hand_base_link"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
retargeting:
type: DexPilot
urdf_path: schunk_hand/schunk_svh_hand_right.urdf
wrist_link_name: "right_hand_base_link"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
wrist_link_name: "right_hand_base_link"
finger_tip_link_names: [ "right_hand_c", "right_hand_t", "right_hand_s", "right_hand_r", "right_hand_q" ]
scaling_factor: 1.2

Expand Down
1 change: 1 addition & 0 deletions dex_retargeting/configs/teleop/shadow_hand_right.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
retargeting:
type: vector
urdf_path: shadow_hand/shadow_hand_right.urdf
wrist_link_name: "ee_link"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
retargeting:
type: DexPilot
urdf_path: shadow_hand/shadow_hand_right.urdf
wrist_link_name: "ee_link"

# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
Expand Down
19 changes: 15 additions & 4 deletions dex_retargeting/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ class Optimizer:
retargeting_type = "BASE"

def __init__(
self, robot: sapien.Articulation, target_joint_names: List[str], target_link_human_indices: np.ndarray
self,
robot: sapien.Articulation,
wrist_link_name: str,
target_joint_names: List[str],
target_link_human_indices: np.ndarray,
):
self.robot = robot
self.robot_dof = robot.dof
self.model = robot.create_pinocchio_model()
self.wrist_link_name = wrist_link_name

joint_names = [joint.get_name() for joint in robot.get_active_joints()]
target_joint_index = []
Expand All @@ -32,6 +37,10 @@ def __init__(
# Target
self.target_link_human_indices = target_link_human_indices

# Free joint
link_names = [link.get_name() for link in self.robot.get_links()]
self.has_free_joint = len([name for name in link_names if "dummy" in name]) >= 6

def set_joint_limit(self, joint_limits: np.ndarray):
if joint_limits.shape != (self.dof, 2):
raise ValueError(f"Expect joint limits have shape: {(self.dof, 2)}, but get {joint_limits.shape}")
Expand Down Expand Up @@ -72,13 +81,14 @@ class PositionOptimizer(Optimizer):
def __init__(
self,
robot: sapien.Articulation,
wrist_link_name: str,
target_joint_names: List[str],
target_link_names: List[str],
target_link_human_indices: np.ndarray,
huber_delta=0.02,
norm_delta=4e-3,
):
super().__init__(robot, target_joint_names, target_link_human_indices)
super().__init__(robot, wrist_link_name, target_joint_names, target_link_human_indices)
self.body_names = target_link_names
self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta)
self.norm_delta = norm_delta
Expand Down Expand Up @@ -166,6 +176,7 @@ class VectorOptimizer(Optimizer):
def __init__(
self,
robot: sapien.Articulation,
wrist_link_name: str,
target_joint_names: List[str],
target_origin_link_names: List[str],
target_task_link_names: List[str],
Expand All @@ -174,7 +185,7 @@ def __init__(
norm_delta=4e-3,
scaling=1.0,
):
super().__init__(robot, target_joint_names, target_link_human_indices)
super().__init__(robot, wrist_link_name, target_joint_names, target_link_human_indices)
self.origin_link_names = target_origin_link_names
self.task_link_names = target_task_link_names
self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta, reduction="mean")
Expand Down Expand Up @@ -326,7 +337,7 @@ def __init__(
target_origin_link_names = [link_names[index] for index in origin_link_index]
target_task_link_names = [link_names[index] for index in task_link_index]

super().__init__(robot, target_joint_names, target_link_human_indices)
super().__init__(robot, wrist_link_name, target_joint_names, target_link_human_indices)
self.origin_link_names = target_origin_link_names
self.task_link_names = target_task_link_names
self.scaling = scaling
Expand Down
12 changes: 9 additions & 3 deletions dex_retargeting/retargeting_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
class RetargetingConfig:
type: str
urdf_path: str

# The link on the robot hand which corresponding to the wrist of human hand
wrist_link_name: str

# Whether to add free joint to the root of the robot. Free joint enable the robot hand move freely in the space
add_dummy_free_joint: bool = False

# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names

target_link_human_indices: Optional[np.ndarray] = None

# Position retargeting link names
Expand All @@ -32,7 +36,6 @@ class RetargetingConfig:

# DexPilot retargeting link names
finger_tip_link_names: Optional[List[str]] = None
wrist_link_name: Optional[str] = None

# Scaling factor for vector retargeting only
# For example, Allegro is 1.6 times larger than normal human hand, then this scaling factor should be 1.6
Expand Down Expand Up @@ -128,7 +131,7 @@ def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting:
# Process the URDF with yourdfpy to better find file path
robot_urdf = urdf.URDF.load(self.urdf_path)
urdf_name = self.urdf_path.split("/")[-1]
temp_dir = tempfile.mkdtemp(prefix="teleop-")
temp_dir = tempfile.mkdtemp(prefix="dex_retargeting-")
temp_path = f"{temp_dir}/{urdf_name}"
robot_urdf.write_xml_file(temp_path)
sapien_model = SAPIENKinematicsModelStandalone(
Expand All @@ -138,6 +141,7 @@ def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting:
scene=scene,
)
robot = sapien_model.robot
robot.set_name(Path(self.urdf_path).stem)
joint_names = (
self.target_joint_names
if self.target_joint_names is not None
Expand All @@ -146,6 +150,7 @@ def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting:
if self.type == "position":
optimizer = PositionOptimizer(
robot,
self.wrist_link_name,
joint_names,
target_link_names=self.target_link_names,
target_link_human_indices=self.target_link_human_indices,
Expand All @@ -155,6 +160,7 @@ def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting:
elif self.type == "vector":
optimizer = VectorOptimizer(
robot,
self.wrist_link_name,
joint_names,
target_origin_link_names=self.target_origin_link_names,
target_task_link_names=self.target_task_link_names,
Expand Down
70 changes: 68 additions & 2 deletions dex_retargeting/seq_retarget.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
from time import time
from typing import Optional

import numpy as np
import transforms3d

from dex_retargeting.optimizer import Optimizer
from dex_retargeting.optimizer_utils import LPFilter
from typing import Optional


class SeqRetargeting:
Expand Down Expand Up @@ -34,9 +36,73 @@ def __init__(
# Filter
self.filter = lp_filter

# Warm started
self.is_warm_started = False

# TODO: hack here
self.scene = None

def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, global_rot: np.array):
# This function can only be used when the first joints of robot are free joints
if len(wrist_pos) != 3:
raise ValueError(f"Wrist pos:{wrist_pos} is not a 3-dim vector.")
if len(wrist_orientation) != 3:
raise ValueError(f"Wrist orientation:{wrist_orientation} is not a 3-dim vector.")

if np.linalg.norm(wrist_orientation) < 1e-3:
mat = np.eye(3)
else:
angle = np.linalg.norm(wrist_orientation)
axis = wrist_orientation / angle
mat = transforms3d.axangles.axangle2mat(axis, angle)
print(transforms3d.quaternions.axangle2quat(axis, angle))

robot = self.optimizer.robot
operator2mano = np.array([[0, 0, -1], [-1, 0, 0], [0, 1, 0]])
mat = global_rot.T @ mat @ operator2mano
target_wrist_pose = np.eye(4)
target_wrist_pose[:3, :3] = mat
target_wrist_pose[:3, 3] = wrist_pos

wrist_link_name = self.optimizer.wrist_link_name
wrist_link = [link for link in self.optimizer.robot.get_links() if link.get_name() == wrist_link_name][0]
name_list = [
"dummy_x_translation_joint",
"dummy_y_translation_joint",
"dummy_z_translation_joint",
"dummy_x_rotation_joint",
"dummy_y_rotation_joint",
"dummy_z_rotation_joint",
]
old_qpos = robot.get_qpos()
new_qpos = old_qpos.copy()
for num, joint_name in enumerate(self.optimizer.target_joint_names):
if joint_name in name_list:
new_qpos[num] = 0
robot.set_qpos(new_qpos)
root2wrist = (robot.get_pose().inv() * wrist_link.get_pose()).inv().to_transformation_matrix()
target_root_pose = target_wrist_pose @ root2wrist
robot.set_qpos(old_qpos)

euler = transforms3d.euler.mat2euler(target_root_pose[:3, :3], "rxyz")
pose_vec = np.concatenate([target_root_pose[:3, 3], euler])

# Find the dummy joints
name_list = [
"dummy_x_translation_joint",
"dummy_y_translation_joint",
"dummy_z_translation_joint",
"dummy_x_rotation_joint",
"dummy_y_rotation_joint",
"dummy_z_rotation_joint",
]
for num, joint_name in enumerate(self.optimizer.target_joint_names):
if joint_name in name_list:
index = name_list.index(joint_name)
self.last_qpos[num] = pose_vec[index]

self.is_warm_started = True

def retarget(self, ref_value, fixed_qpos=np.array([])):
tic = time()
qpos = self.optimizer.retarget(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
raise RuntimeError("Unable to find __version__ string.")

core_requirements = [
"numpy",
"numpy<1.24",
"torch",
"sapien>=2.0.0",
"nlopt",
Expand Down

0 comments on commit dcef611

Please sign in to comment.