Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/changhaonan/LGMCTS-D
Browse files Browse the repository at this point in the history
  • Loading branch information
Kowndinya2000 committed Sep 13, 2023
2 parents 314ffb5 + 033a3e3 commit 12e5be3
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 308 deletions.
136 changes: 67 additions & 69 deletions lgmcts/algorithm/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class Sampler:
"""
manipulating_object, aligning_object, direction
"""

def __init__(self, obj_name, origin_name, direction, region: Region2DSampler):
self.obj_name = obj_name
self.origin_name = origin_name
self.direction = direction
self.region = region


class Node(object):
"""MCTS Node"""
Expand All @@ -41,10 +42,10 @@ def __init__(
updated_obj_id=None,
UCB_scalar=1.0,
num_sampling=1,
obj_support_tree:anytree.Node = None,
obj_support_tree: anytree.Node = None,
prior_dict={},
verbose=False,
rng = None
rng=None
) -> None:

self.node_id = node_id
Expand All @@ -67,24 +68,24 @@ def __init__(
self.verbose = verbose
self.rng = rng

self.segmentation = None # segmentation of the workspace, will be generated only once when needed
self.segmentation = None # segmentation of the workspace, will be generated only once when needed

def generate_actions(self):
"""
generate the list of actions for this node.
That is, what samplers can be sampled
without breaking the pattern ordering
"""
no_sample_objs = set() # objects that cannot be sampled because of ordering
for obj_id, sampler in self.sampler_dict.items(): # check ordering
no_sample_objs = set() # objects that cannot be sampled because of ordering
for obj_id, sampler in self.sampler_dict.items(): # check ordering
ordered = sampler.sample_info.get("ordered", False)
if ordered:
prior_objs = sampler.obj_ids[:sampler.obj_ids.index(obj_id)]
for prior_obj in prior_objs:
if prior_obj in self.sampler_dict:
no_sample_objs.add(obj_id)
break

return [obj_id for obj_id in self.sampler_dict.keys() if obj_id not in no_sample_objs]

def UCB(self):
Expand Down Expand Up @@ -133,45 +134,45 @@ def action_parametriczation(self, action):
# check graspability
found_node = anytree.search.find(
self.obj_support_tree, lambda node: node.name == action[0]
)
)
if found_node and len(found_node.children) > 0:
# not graspable, move a leave on the subtree away
# Search for all leaf nodes
leaf_nodes = found_node.leaves
moved_obj = self.rng.choice(leaf_nodes).name
# add a sampler to move the obstacle away
buffer_sampler = SampleData(
pattern="line",
obj_id = moved_obj,
obj_ids = [moved_obj],
obj_poses_pix = {})
pattern="line",
obj_id=moved_obj,
obj_ids=[moved_obj],
obj_poses_pix={})
success, _, (moved_obj, new_position) = self.sampling_function(
self.region_sampler,
self.object_states,
buffer_sampler
)
solved_sampler_obj_id = float('inf')
return action, moved_obj, new_position, solved_sampler_obj_id

sampler = self.sampler_dict[action[0]]
success, obs, (moved_obj, new_position) = self.sampling_function(
self.region_sampler,
self.object_states,
sampler,
)
solved_sampler_obj_id, _ = action
if not success: # fails to complete the sampling, do
if not success: # fails to complete the sampling, do
if obs is None:
# fails but not because of collision (e.g., out of workspace)
solved_sampler_obj_id = float('inf')
moved_obj = None
else:
# add a sampler to move the obstacle away
buffer_sampler = SampleData(
pattern="line",
obj_id = obs,
obj_ids = [obs],
obj_poses_pix = {})
pattern="line",
obj_id=obs,
obj_ids=[obs],
obj_poses_pix={})
success, _, (moved_obj, new_position) = self.sampling_function(
self.region_sampler,
self.object_states,
Expand All @@ -181,12 +182,12 @@ def action_parametriczation(self, action):
return action, moved_obj, new_position, solved_sampler_obj_id

def sampling_function(
self,
region: Region2DSampler,
object_states: dict,
sample_data: SampleData,
verbose: bool = False,
):
self,
region: Region2DSampler,
object_states: dict,
sample_data: SampleData,
verbose: bool = False,
):
"""
sampling function
If sampling succeeded, return True, None, (moved_obj_id, new_pose)
Expand All @@ -197,51 +198,51 @@ def sampling_function(
success, obs_name, action:(obj_name, new_pos)
"""
obj_id = sample_data.obj_id

# update region
region.set_object_poses(obj_states=object_states)
# region.visualize()
# keep track of sampled object poses
sampled_obj_poses_pix = {}
sampled_obj_poses_pix = {}
pattern_objs = sample_data.obj_ids # objects involved in the sampling pattern
objs_away_from_goal = list(self.sampler_dict.keys()) # pattern objects away from goal
objs_at_goal = [
pattern_obj for pattern_obj in pattern_objs
if (pattern_obj != obj_id) and (pattern_obj not in objs_away_from_goal)
] # pattern objects at goal
#FIXME: this could be a problem here, because there is an offset
pattern_obj for pattern_obj in pattern_objs
if (pattern_obj != obj_id) and (pattern_obj not in objs_away_from_goal)
] # pattern objects at goal
# FIXME: this could be a problem here, because there is an offset
sampled_obj_poses_pix = {
obj:region._world2pix(object_states[obj][:3] + region.objects[obj].pos_offset)
obj: region._world2pix(object_states[obj][:3] + region.objects[obj].pos_offset)
for obj in objs_at_goal}

# update prior
if sample_data.pattern in self.prior_dict:
prior, pattern_info = self.prior_dict[sample_data.pattern].gen_prior(
region.grid_size, region.rng,
obj_id=sample_data.obj_id,
region.grid_size, region.rng,
obj_id=sample_data.obj_id,
obj_ids=sample_data.obj_ids,
obj_poses_pix=sampled_obj_poses_pix,
sample_info = sample_data.sample_info
)
sample_info=sample_data.sample_info
)
# cv2.imshow("prior", prior)
# cv2.waitKey(0)
# the prior object is too close to the boundary so that no sampling is possible
if np.sum(prior) <= 0:
obs = self.rng.choice([obj for obj in sample_data.obj_ids if obj != obj_id])
return False, obs, (obj_id, None)
# sample
valid_pose, _, samples_status, _ = region.sample(sample_data.obj_id, 1, prior, allow_outside=False)
valid_pose, _, samples_status, _ = region.sample(sample_data.obj_id, 1, prior, allow_outside=False, pattern_info=pattern_info)
if valid_pose.shape[0] > 0:
valid_pose = valid_pose.reshape(-1)
else:
raise NotImplementedError

success = samples_status == SampleStatus.SUCCESS

# test
# print(f"sample status: {samples_status.name}, valid_pose: {valid_pose}")

if not success: # find an obstacle
if not success: # find an obstacle
if self.segmentation is None:
self.segmentation = self.semantic_segmentation(region)
leaf_nodes = self.obj_support_tree.leaves
Expand All @@ -250,7 +251,7 @@ def sampling_function(
while (counter > 0):
counter -= 1
sample_pix, sample_probs = sample_distribution(prob=prior, rng=region.rng, n_samples=1) # (N, 2)
obs_id = self.segmentation[sample_pix[0][0], sample_pix[0][1], 0]
obs_id = self.segmentation[sample_pix[0][0], sample_pix[0][1], 0]
if (obs_id not in [-1, obj_id]) and (obs_id in leaf_objs):
break
if counter <= 0:
Expand All @@ -261,16 +262,17 @@ def sampling_function(

return success, obs_id, action

def semantic_segmentation(self, region:Region2DSampler):
#TODO: Merge this part into sampler
def semantic_segmentation(self, region: Region2DSampler):
# TODO: Merge this part into sampler
# semetic segmentation of the workspace
segmentation = -1.0 * np.ones((region.grid_size[0], region.grid_size[1], 3), dtype=np.float32)
# objects
for obj_id, obj_data in region.objects.items():
region._put_mask(
mask=obj_data.mask,
pos=obj_data.pos,
occupancy_map=segmentation,
rot=obj_data.rot,
region_map=segmentation,
value=float(obj_id),
)
return segmentation
Expand All @@ -292,22 +294,22 @@ def __init__(
region_sampler: Region2DSampler,
L: List[SampleData],
UCB_scalar=1.0,
obj_support_tree:anytree.Node = None,
obj_support_tree: anytree.Node = None,
prior_dict={},
n_samples = 1,
n_samples=1,
verbose: bool = False,
seed = 0
seed=0
) -> None:
self.rng = np.random.default_rng(seed=seed)
self.settings = {
"UCB_scalar": UCB_scalar,
"prior_dict": prior_dict,
"rng": self.rng,
"num_sampling" : n_samples
"num_sampling": n_samples
}
self.region_sampler = region_sampler
self.sampler_dict = {s.obj_id: s for s in L}
self.obj_support_tree = obj_support_tree # initial object support tree
self.obj_support_tree = obj_support_tree # initial object support tree
self.start_state = region_sampler.get_object_poses()

# intialize MCTS tree
Expand Down Expand Up @@ -355,14 +357,14 @@ def search(self, max_iter: int = 10000, log_step: int = 1000) -> bool:

while num_iter < max_iter:
if (num_iter % log_step) == 0:
print(num_iter)
print(f"Searched {num_iter}/{max_iter} iterations")
num_iter += 1
current_node = self.selection()
# an action in MCTS is represented by (sampler_id, trail_id),
# an action in MCTS is represented by (sampler_id, trail_id),
# the index is according to L and the num_sample children list
#TODO: do K sampling at the same time @KAI
# TODO: do K sampling at the same time @KAI
action, moved_obj, new_position, solved_sampler_obj_id = current_node.expansion()
if (new_position.shape[0] > 0): # go to a new state
if (new_position.shape[0] > 0): # go to a new state
new_node = self.move(
num_iter,
action,
Expand All @@ -371,7 +373,7 @@ def search(self, max_iter: int = 10000, log_step: int = 1000) -> bool:
solved_sampler_obj_id,
current_node,
)
else: # stay in the same state
else: # stay in the same state
new_node = Node(
num_iter,
region_sampler=self.region_sampler,
Expand All @@ -391,7 +393,7 @@ def search(self, max_iter: int = 10000, log_step: int = 1000) -> bool:
current_node.children[action[0]].append(new_node)

# update reward
#TODO: new reward function @KAI
# TODO: new reward function @KAI
reward = self.reward_detection(new_node)
self.back_propagation(new_node, reward)
if reward == len(self.sampler_dict):
Expand Down Expand Up @@ -425,9 +427,9 @@ def move(
}
# print(f"id: {node_id}, obj_states: {new_object_states}, target: {target}")

new_sampler_dict = {obj_id:sampler for obj_id, sampler in current_node.sampler_dict.items() if obj_id != solved_sampler_obj_id}
# If we are moving an obstacle, the moved object may be an object moved to goal,
new_sampler_dict = {obj_id: sampler for obj_id, sampler in current_node.sampler_dict.items() if obj_id != solved_sampler_obj_id}

# If we are moving an obstacle, the moved object may be an object moved to goal,
# we need to retrive the sampler to indicate that this sampler needs to be solved again
if solved_sampler_obj_id == float("inf"):
backtracked_node = current_node
Expand Down Expand Up @@ -484,29 +486,25 @@ def construct_plan(self, node: Node):
moved_object = current_node.updated_obj_id
# current_node.show_arrangement()
if moved_object is not None:
old_pose = np.concatenate(
[parent_node.object_states[moved_object][:3], np.array([0, 0, 0, 1])],
axis=0).reshape(-1).astype(np.float32)
new_pose = np.concatenate(
[current_node.object_states[moved_object][:3], np.array([0, 0, 0, 1])],
axis=0).reshape(-1).astype(np.float32)
old_pose = parent_node.object_states[moved_object].reshape(-1).astype(np.float32)
new_pose = current_node.object_states[moved_object].reshape(-1).astype(np.float32)
self.action_list.append(
{
"obj_id": moved_object,
"old_pose": old_pose,
"new_pose": new_pose,
}
"obj_id": moved_object,
"old_pose": old_pose,
"new_pose": new_pose,
}
)
current_node = parent_node
self.action_list.reverse()


# copy anytree
def copy_tree(node:anytree.Node):
def copy_tree(node: anytree.Node):
copied_node = anytree.Node(copy.deepcopy(node.name))

for child in node.children:
child_copy = copy_tree(child)
child_copy.parent = copied_node

return copied_node
return copied_node
Loading

0 comments on commit 12e5be3

Please sign in to comment.