Skip to content

Commit

Permalink
[update] update video detection pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
yzqin committed Aug 26, 2023
1 parent d799bf6 commit edd8220
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 129 deletions.
39 changes: 39 additions & 0 deletions dex_retargeting/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import enum
from pathlib import Path


class RobotName(enum.Enum):
    """Identifiers for the robot hands known to this module.

    Each member is a key of ``ROBOT_NAME_MAP``, which supplies the file-name
    prefix that ``get_config_path`` uses for the robot's config file.
    Values are assigned by ``enum.auto()`` in declaration order.
    """

    allegro = enum.auto()  # maps to "allegro_hand" in ROBOT_NAME_MAP
    shadow = enum.auto()  # maps to "shadow_hand" in ROBOT_NAME_MAP
    svh = enum.auto()  # maps to "schunk_svh_hand" in ROBOT_NAME_MAP


class RetargetingType(enum.Enum):
    """Selector for the kind of retargeting a config is written for.

    ``get_config_path`` only distinguishes ``position`` from the rest:
    ``position`` resolves to the "offline" config directory, while every
    other member resolves to the "teleop" directory.
    """

    vector = enum.auto()
    position = enum.auto()
    dexpilot = enum.auto()


class HandType(enum.Enum):
    """Handedness of the robot hand.

    The member *name* ("right" / "left") is used verbatim by
    ``get_config_path`` as the suffix of the config file name.
    """

    right = enum.auto()
    left = enum.auto()


# File-name prefix of each robot's config file, keyed by RobotName member.
# get_config_path() appends "_<hand type>.yml" to these prefixes.
ROBOT_NAME_MAP = {
    RobotName.allegro: "allegro_hand",
    RobotName.shadow: "shadow_hand",
    RobotName.svh: "schunk_svh_hand",
}


def get_config_path(robot_name: RobotName, retargeting_type: RetargetingType, hand_type: HandType) -> Path:
    """Build the path of the packaged YAML config for a robot hand.

    The result has the form
    ``<this package>/configs/<offline|teleop>/<robot prefix>_<hand>.yml``.
    The path is only constructed; this function does not check that the
    file exists.

    Args:
        robot_name: Robot whose file-name prefix is looked up in
            ``ROBOT_NAME_MAP``.
        retargeting_type: ``RetargetingType.position`` selects the
            "offline" directory; any other value selects "teleop".
        hand_type: Its ``name`` ("right" / "left") becomes the file-name
            suffix.

    Returns:
        Path to the config file, rooted at the directory containing this
        module.

    Raises:
        KeyError: If ``robot_name`` has no entry in ``ROBOT_NAME_MAP``.
    """
    # Only position retargeting uses the "offline" configs; vector and
    # dexpilot both share the "teleop" configs.
    sub_dir = "offline" if retargeting_type is RetargetingType.position else "teleop"

    robot_name_str = ROBOT_NAME_MAP[robot_name]
    config_name = f"{robot_name_str}_{hand_type.name}.yml"
    return Path(__file__).parent / "configs" / sub_dir / config_name
7 changes: 3 additions & 4 deletions dex_retargeting/retargeting_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Dict
from pathlib import Path
from typing import List, Optional, Dict
from typing import Union

import numpy as np
import yaml
Expand Down Expand Up @@ -97,7 +98,7 @@ def set_default_urdf_dir(cls, urdf_dir):
cls._DEFAULT_URDF_DIR = urdf_dir

@classmethod
def load_from_file(cls, config_path, override: Optional[Dict] = None):
def load_from_file(cls, config_path: Union[str, Path], override: Optional[Dict] = None):
path = Path(config_path)
if not path.is_absolute():
path = path.absolute()
Expand All @@ -118,7 +119,6 @@ def build(self) -> SeqRetargeting:
VectorOptimizer,
PositionOptimizer,
DexPilotAllegroOptimizer,
DexPilotAllegroV4Optimizer,
)
from dex_retargeting.optimizer_utils import SAPIENKinematicsModelStandalone
from dex_retargeting import yourdfpy as urdf
Expand Down Expand Up @@ -191,7 +191,6 @@ def get_retargeting_config(config_path) -> RetargetingConfig:

if __name__ == "__main__":
# Path below is relative to this file
from pathlib import Path

test_config = get_retargeting_config(str(Path(__file__).parent / "configs/allegro_hand.yml"))
print(test_config)
Expand Down
177 changes: 52 additions & 125 deletions example/detect_from_video.py
Original file line number Diff line number Diff line change
@@ -1,146 +1,73 @@
import pickle
from pathlib import Path

import cv2
import numpy as np
import sapien.core as sapien
from sapien.asset import create_dome_envmap
import tqdm
import tyro

from dex_retargeting.retargeting_config import get_retargeting_config, RetargetingConfig
from dex_retargeting.constants import RobotName, RetargetingType, HandType, get_config_path
from dex_retargeting.retargeting_config import RetargetingConfig
from dex_retargeting.seq_retarget import SeqRetargeting
from single_hand_detector import SingleHandDetector

RECORD_VIDEO = False

def retarget_video(retargeting: SeqRetargeting, video_path: str, output_path: str, config_path: str):
cap = cv2.VideoCapture(video_path)

def setup_sapien_viz_scene(urdf_path, robot_name=None):
    """Create a SAPIEN scene containing the robot loaded from ``urdf_path``.

    Sets up the engine and renderer, an environment map, one directional
    light, the robot articulation, and a 1280x720 camera named "photo",
    then overrides the material of every render shape of the robot.

    Args:
        urdf_path: Path of the URDF file passed to the scene's URDF loader.
        robot_name: Optional short robot name ("shadow", "schunk" or "dlr")
            used to pick a vertical offset for the robot pose. When omitted,
            the module-level ``robot_name`` global is used if it exists, so
            existing script usage keeps working; previously the function
            raised ``NameError`` whenever that global had not been defined
            (for example when the module was imported rather than run).

    Returns:
        Tuple ``(scene, viewer)``. ``viewer`` is ``None`` unless the
        module-level ``RECORD_VIDEO`` flag is true.

    NOTE(review): the ``RECORD_VIDEO`` flag reads inverted relative to its
    name. The interactive viewer is created when it is true, and the "rt"
    shader settings are applied when it is false. Confirm the intended
    meaning before changing it.
    """
    from sapien.utils.viewer import Viewer
    import sapien.core as sapien

    if robot_name is None:
        # Fall back to the global that the script entry point assigns.
        robot_name = globals().get("robot_name")

    engine = sapien.Engine()
    engine.set_log_level("warning")

    if not RECORD_VIDEO:
        # Render configuration must be set before the renderer is created.
        sapien.render_config.camera_shader_dir = "rt"
        sapien.render_config.viewer_shader_dir = "rt"
        sapien.render_config.rt_samples_per_pixel = 32
        sapien.render_config.rt_use_denoiser = True

    renderer = sapien.SapienRenderer()
    engine.set_renderer(renderer)

    scene_config = sapien.SceneConfig()
    scene = engine.create_scene(scene_config)
    scene.set_timestep(1 / 240)

    scene.set_environment_map(create_dome_envmap(sky_color=[0.2, 0.2, 0.2], ground_color=[0.2, 0.2, 0.2]))
    scene.add_directional_light([-1, 0.5, -1], color=[2.0, 2.0, 2.0], shadow=True, scale=2.0, shadow_map_size=4096)

    loader = scene.create_urdf_loader()
    robot = loader.load(urdf_path)

    # Per-robot z offset of the root pose; robots not listed keep the
    # pose the loader gave them.
    z_offsets = {"shadow": -0.3, "schunk": -0.05, "dlr": -0.08}
    if robot_name in z_offsets:
        robot.set_pose(sapien.Pose([0, 0, z_offsets[robot_name]]))

    camera = scene.add_camera(name="photo", width=1280, height=720, fovy=1, near=0.1, far=10)
    camera.set_local_pose(
        sapien.Pose([0.313487, 0.0653831, -0.0111697], [0.088142, -0.0298786, -0.00264502, -0.995656])
    )

    if RECORD_VIDEO:
        viewer = Viewer(renderer)
        viewer.set_scene(scene)
    else:
        viewer = None

    # Give every render shape of the robot the same grey material.
    for actor in robot.get_links():
        for visual in actor.get_visual_bodies():
            for mesh in visual.get_render_shapes():
                mat = mesh.material
                mat.set_base_color(np.array([0.3, 0.3, 0.3, 1]))
                mat.set_specular(0.7)
                mat.set_metallic(0.1)

    return scene, viewer


def build_retargeting(robot_name):
    """Load the right-hand teleop config for ``robot_name`` and build it.

    Args:
        robot_name: One of "allegro", "shadow" or "schunk".

    Returns:
        Tuple ``(seq_retargeting, config)``: the object returned by the
        config's ``build()`` and the config loaded by
        ``get_retargeting_config``.

    Raises:
        ValueError: If ``robot_name`` is not one of the recognized names.
    """
    # Config file of each supported robot, relative to dex_retargeting/configs.
    config_files = {
        "allegro": "teleop/allegro_hand_right.yml",
        "shadow": "teleop/shadow_hand_right.yml",
        "schunk": "teleop/schunk_svh_hand_right.yml",
    }
    try:
        config_path = config_files[robot_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which the original
        # if/elif chain also reported as an unrecognized name.
        raise ValueError(f"Unrecognized robot_name: {robot_name}") from None

    # The configs live in the dex_retargeting package one level above
    # this file's directory.
    test_config = get_retargeting_config(str(Path(__file__).parent.parent / f"dex_retargeting/configs/{config_path}"))
    seq_retargeting = test_config.build()
    return seq_retargeting, test_config


def retarget_video(seq_retargeting: SeqRetargeting, scene: sapien.Scene, viewer):
video_path = Path(__file__).parent / "data/human_hand_video.mp4"
cap = cv2.VideoCapture(str(video_path))
robot = scene.get_all_articulations()[0]

if not RECORD_VIDEO:
writer = cv2.VideoWriter(f"data/output_{robot_name}.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 30.0, (1280, 720))
data = []

if not cap.isOpened():
print("Error: Could not open video file.")
else:
detector = SingleHandDetector(hand_type="Right", selfie=False)
while cap.isOpened():
ret, frame = cap.read()

if not ret:
break

rgb = frame[..., ::-1]
num_box, joint_pos, keypoint_2d, mediapipe_wrist_rot = detector.detect(rgb)

retargeting_type = seq_retargeting.optimizer.retargeting_type
indices = seq_retargeting.optimizer.target_link_human_indices
if retargeting_type == "VECTOR":
origin_indices = indices[0, :]
task_indices = indices[1, :]
ref_value = joint_pos[task_indices, :] - joint_pos[origin_indices, :]
elif retargeting_type == "POSITION":
indices = indices
ref_value = joint_pos[indices, :]
else:
raise ValueError(f"Unknown retargeting type: {retargeting_type}")
qpos = retargeting.retarget(ref_value)
robot.set_qpos(qpos)
if RECORD_VIDEO:
for _ in range(3):
viewer.render()
else:
cam = scene.get_cameras()[0]
scene.update_render()
cam.take_picture()
rgb = cam.get_texture("Color")[..., :3]
rgb = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
seg = cam.get_visual_actor_segmentation()[..., 0] < 1
rgb[seg, :] = [255, 255, 255]
writer.write(rgb[..., ::-1])
print(seq_retargeting.num_retargeting)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
with tqdm.tqdm(total=length) as pbar:
while cap.isOpened():
ret, frame = cap.read()

if not ret:
break

rgb = frame[..., ::-1]
num_box, joint_pos, keypoint_2d, mediapipe_wrist_rot = detector.detect(rgb)

retargeting_type = retargeting.optimizer.retargeting_type
indices = retargeting.optimizer.target_link_human_indices
if retargeting_type == "POSITION":
indices = indices
ref_value = joint_pos[indices, :]
else:
origin_indices = indices[0, :]
task_indices = indices[1, :]
ref_value = joint_pos[task_indices, :] - joint_pos[origin_indices, :]
qpos = retargeting.retarget(ref_value)
data.append(qpos)

meta_data = dict(
config_path=config_path,
dof=retargeting.optimizer.dof,
joint_names=retargeting.optimizer.target_joint_names,
)

output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as f:
pickle.dump(dict(data=data, meta_data=meta_data), f)
pbar.update(1)

cap.release()
cv2.destroyAllWindows()
if not RECORD_VIDEO:
writer.release()


if __name__ == "__main__":
robot_name = ["allegro", "shadow", "schunk"][1]
def main(
robot_name: RobotName, video_path: str, output_path: str, retargeting_type: RetargetingType, hand_type: HandType
):
config_path = get_config_path(robot_name, retargeting_type, hand_type)
robot_dir = Path(__file__).parent.parent / "assets" / "robots"
RetargetingConfig.set_default_urdf_dir(str(robot_dir))
retargeting, cfg = build_retargeting(robot_name)
scene, viewer = setup_sapien_viz_scene(cfg.urdf_path)
retarget_video(retargeting, scene, viewer)
retargeting = RetargetingConfig.load_from_file(config_path).build()
retarget_video(retargeting, video_path, output_path, str(config_path))


if __name__ == "__main__":
tyro.cli(main)

0 comments on commit edd8220

Please sign in to comment.