Skip to content

Commit

Permalink
enable spatial pattern planning
Browse files Browse the repository at this point in the history
  • Loading branch information
gaokai15 committed Sep 3, 2023
1 parent 420867e commit 61cf90b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
18 changes: 13 additions & 5 deletions lgmcts/algorithm/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def generate_actions(self):
"""
no_sample_objs = set() # objects that cannot be sampled because of ordering
for obj_id, sampler in self.sampler_dict.items(): # check ordering
if obj_id in no_sample_objs:
continue
if sampler.pattern in ORDERED_PATTERNS:
post_obj = sampler.obj_ids[sampler.obj_ids.index(obj_id)+1:]
no_sample_objs = no_sample_objs.union(set(post_obj))
prior_objs = sampler.obj_ids[:sampler.obj_ids.index(obj_id)]
for prior_obj in prior_objs:
if prior_obj in self.sampler_dict:
no_sample_objs.add(obj_id)
break

return [obj_id for obj_id in self.sampler_dict.keys() if obj_id not in no_sample_objs]

def UCB(self):
Expand Down Expand Up @@ -218,9 +220,15 @@ def sampling_function(
region.grid_size, region.rng,
obj_id=sample_data.obj_id,
obj_ids=sample_data.obj_ids,
obj_poses_pix=sampled_obj_poses_pix)
obj_poses_pix=sampled_obj_poses_pix,
sample_info = sample_data.sample_info
)
# cv2.imshow("prior", prior)
# cv2.waitKey(0)
# the prior object is too close to the boundary so that no sampling is possible
if np.sum(prior) <= 0:
obs = self.rng.choice([obj for obj in sample_data.obj_ids if obj != obj_id])
return False, obs, (obj_id, None)
# sample
valid_pose, _, samples_status, _ = region.sample(sample_data.obj_id, 1, prior,allow_outside=False)
if valid_pose.shape[0] > 0:
Expand Down
2 changes: 1 addition & 1 deletion lgmcts/components/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def gen_prior(cls, img_size, rng, **kwargs):
return prior, {}

# parse spatial label
spatial_label = sample_info["spatial_label"] # [left, right, front, back]
spatial_label = list(sample_info["spatial_label"]) # [left, right, front, back]
if spatial_label == [1, 0, 0, 0]:
# left
prior[:, int(anchor[0]):] = 1.0
Expand Down
3 changes: 2 additions & 1 deletion lgmcts/scripts/eval/eval_lgmcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def eval_offline(dataset_path: str, method: str, mask_mode: str, n_samples: int

checkpoint_list = list(filter(lambda f: f.endswith(".pkl"), os.listdir(dataset_path)))
checkpoint_list.sort()
for i in range(min(n_epoches, len(checkpoint_list))):
# for i in range(min(n_epoches, len(checkpoint_list))):
for i in range(n_epoches):
print(f"==== Episode {i} ====")
## Step 1. init the env from dataset
env.reset()
Expand Down

0 comments on commit 61cf90b

Please sign in to comment.