Skip to content

Commit

Permalink
update: add force anchor exclude
Browse files Browse the repository at this point in the history
  • Loading branch information
changhaonan committed Sep 11, 2023
1 parent 8f85dc6 commit ecf4919
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 47 deletions.
18 changes: 9 additions & 9 deletions lgmcts/components/obj_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def set_objs(self, obj_list: list[ObjEntry], texture_list: list[TextureEntry]):
for id, (obj_entry, texture_entry) in enumerate(zip(obj_list, texture_list)):
self.obj_bag_list.append(get_object_bag(id, obj_entry.value, texture_entry.value))

def select_obj(self, anchor_obj_bag: ObjectBag, attribute: str, compare_rel: CompareRel):
def select_obj(self, anchor_obj_bag: ObjectBag, attribute: str, compare_rel: CompareRel, force_anchor_exclude: bool = False):
"""Select object based on attribute
"""
assert attribute in COMPARE_DICT, f"Attribute {attribute} not supported"
in_obj, in_color, in_size, out_obj, out_color, out_size = [], [], [], [], [], []
# check self-include
self_include = False
if compare_rel == EqualRel():
in_obj.append(self.obj_list[anchor_obj_bag.obj_id])
in_color.append(self.texture_list[anchor_obj_bag.obj_id])
# selected_size.append(self.size_list[anchor_obj_bag.size_id])
self_include = True
if not force_anchor_exclude:
if compare_rel == EqualRel():
in_obj.append(self.obj_list[anchor_obj_bag.obj_id])
in_color.append(self.texture_list[anchor_obj_bag.obj_id])
self_include = True
for obj_bag in self.obj_bag_list:
if compare_rel == COMPARE_DICT[attribute](obj_bag, anchor_obj_bag):
if anchor_obj_bag.obj_id == obj_bag.obj_id:
Expand All @@ -63,7 +63,7 @@ def select_obj(self, anchor_obj_bag: ObjectBag, attribute: str, compare_rel: Com
# out_size.append(self.size_list[obj_bag.size_id])
return self_include, in_obj, in_color, in_size, out_obj, out_color, out_size

def gen_anchor_obj_prompt(self):
def gen_anchor_obj_prompt(self, force_anchor_exclude: bool = False):
"""Based on the obj we have, generate a valid anchor obj prompt"""
# random select anchor
# FIXME: Rewrite the logic for anchor selection
Expand All @@ -83,7 +83,7 @@ def gen_anchor_obj_prompt(self):
prompt_str = f"objects whose {attribute} {compare_rel_str} {anchor_obj}"

# select objects
self_include, in_obj, in_color, in_size, out_obj, out_color, out_size = self.select_obj(anchor_obj_bag, attribute, compare_rel)
self_include, in_obj, in_color, in_size, out_obj, out_color, out_size = self.select_obj(anchor_obj_bag, attribute, compare_rel, force_anchor_exclude)
if len(in_obj) >= 3: # at least 3 objects to formulate a pattern
if not self_include:
anchor_obj = self.obj_list[anchor_obj_bag.obj_id]
Expand All @@ -107,4 +107,4 @@ def gen_anchor_obj_prompt(self):
}
# warnings.warn("Cannot generate a valid prompt")
assert False, "Cannot generate a valid prompt"
return {"anchor_obj": None, "in_obj": [], "in_color": [], "in_size": [], "out_obj": [], "out_color": [], "out_size": []}
return {"anchor_obj": None, "in_obj": [], "in_color": [], "in_size": [], "out_obj": [], "out_color": [], "out_size": []}
2 changes: 1 addition & 1 deletion lgmcts/scripts/data_generation/gen_strdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _generate_data_for_one_task(
obj_selector.reset()

# generate goal
prompt_str, obs = task.gen_goal_config(env, prompt_generator, obj_selector, enable_distract=False)
prompt_str, obs = task.gen_goal_config(env, prompt_generator, obj_selector, enable_distract=False, force_anchor_exclude=True)
obs_cache.append(obs)

# generate start
Expand Down
76 changes: 39 additions & 37 deletions lgmcts/tasks/struct_rearrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ class ResultTuple(NamedTuple):
failure: bool
distance: float | None


class StructRearrange(BaseTask):
"""Structured Rearrange Task"""
task_name = "struct_rearrange"

def __init__(
self,
self,
# ==== task specific ====
max_num_obj: int = 10,
max_num_pattern: int = 2,
Expand All @@ -40,8 +40,8 @@ def __init__(
obs_img_views: str | list[str] | None = None,
obs_img_size: tuple[int, int] = (128, 256),
seed: int | None = None,
debug: bool = False,):
debug: bool = False,):

super().__init__(
modalities=["rgb"],
obs_img_views=obs_img_views,
Expand All @@ -55,7 +55,7 @@ def __init__(
self.pattern_types = pattern_types
self.obj_list = [ObjPedia.lookup_object_by_name(obj) for obj in obj_list]
self.color_list = [TexturePedia.lookup_color_by_name(color) for color in color_list]
#
#
self.obs_img_size = obs_img_size
# temporary data
self.goal_pattern_info = {}
Expand All @@ -70,7 +70,7 @@ def reset(self, env):
self.distract_obj_ids = []

def add_objects_to_pattern(
self, env, objs, colors, pattern_prior: np.ndarray | None, num_limit: list[int]=[0, 100], use_existing: bool=False, stack_prob: float=0.0):
self, env, objs, colors, pattern_prior: np.ndarray | None, num_limit: list[int] = [0, 100], use_existing: bool = False, stack_prob: float = 0.0):
"""Set objects to a line, use_existing decides whether to add new object or not"""
# Add object
added_obj_ids = []
Expand All @@ -97,7 +97,7 @@ def add_objects_to_pattern(
# assert False, "No object is added to the pattern"
return added_obj_ids, obj_status

def add_objects_to_random(self, env, max_num_obj: int, obj_candidates: list=[], color_candidates: list=[], use_existing: bool=False, stack_prob :float=0.0):
def add_objects_to_random(self, env, max_num_obj: int, obj_candidates: list = [], color_candidates: list = [], use_existing: bool = False, stack_prob: float = 0.0):
"""Set objects to random positions
Args:
max_num_obj: maximum number of objects to add
Expand Down Expand Up @@ -131,17 +131,17 @@ def add_objects_to_random(self, env, max_num_obj: int, obj_candidates: list=[],

def gen_goal_spec(self, env):
"""goal specification; used for StructDiffusion"""
#FIXME: what does these contents mean?
# FIXME: what does these contents mean?
spec = super().gen_goal_spec(env)
goal = self.goals[0]
# anchor object
spec["anchor"] = {
"objects": [],
"features" : [
"features": [
{
"comparator" : None,
"type" : "color_d",
"value" : env.obj_id_reverse_mapping[goal["anchor_id"]]['texture_name']
"comparator": None,
"type": "color_d",
"value": env.obj_id_reverse_mapping[goal["anchor_id"]]['texture_name']
}
]
}
Expand All @@ -157,11 +157,11 @@ def gen_goal_spec(self, env):
"combine_features_logic": "None",
"count": "None",
"objects": [],
"features" : [
"features": [
{
"comparator" : None,
"type" : "color_d",
"value" : env.obj_id_reverse_mapping[random.choice(goal["obj_ids"])]['texture_name']
"comparator": None,
"type": "color_d",
"value": env.obj_id_reverse_mapping[random.choice(goal["obj_ids"])]['texture_name']
}
]
}
Expand Down Expand Up @@ -204,36 +204,37 @@ def gen_type_vocabs(self):
type_vocabs["class"][obj.name] = i
type_vocabs["color"] = {}
for i, color in enumerate(self.color_list):

type_vocabs["color"][color.name] = i
print("type_vocabs:", type_vocabs)
return type_vocabs

def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSelector, **kwargs):
"""Generate goal config"""
num_color = kwargs.get("num_color", 2) # Each scene only has X colors
force_anchor_exclude = kwargs.get("force_anchor_exclude", False)
num_added_objs = 0
obj_list = self.rng.choice(self.obj_list, min(self.max_num_obj, len(self.obj_list)), replace=False) # current candidate
color_list = self.rng.choice(self.color_list, num_color, replace=False)
## Step 1: select object candidates
# Step 1: select object candidates
for i in range(max(self.max_num_pattern - 1, 1)):
if obj_list is None or len(obj_list) <= 2:
break # no more enough candidate to formulate pattern
selected_objs = obj_list
selected_colors = self.rng.choice(color_list, len(obj_list), replace=True)
obj_selector.reset()
obj_selector.set_objs(selected_objs, selected_colors)
selection = obj_selector.gen_anchor_obj_prompt()
selection = obj_selector.gen_anchor_obj_prompt(force_anchor_exclude=force_anchor_exclude)
if not selection: # no valid selection
continue
## Step 2: select pattern & add objects to scene
# Step 2: select pattern & add objects to scene
if selection["anchor_obj"] is not None:
[anchor_id], _ = self.add_objects_to_pattern(
env,
objs=[selection["anchor_obj"]],
colors=[selection["anchor_color"]],
env,
objs=[selection["anchor_obj"]],
colors=[selection["anchor_color"]],
pattern_prior=None,
use_existing=False,
use_existing=False,
stack_prob=0.0) # add anchor object
else:
anchor_id = -1
Expand All @@ -246,12 +247,12 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
pattern_prior, pattern_info = PATTERN_DICT[pattern_type].gen_prior(env.ws_map_size, env.rng)
num_limit = PATTERN_DICT[pattern_type]._num_limit
rearrange_obj_ids, obj_status = self.add_objects_to_pattern(
env,
objs=selection["in_obj"],
colors=selection["in_color"],
pattern_prior=pattern_prior,
env,
objs=selection["in_obj"],
colors=selection["in_color"],
pattern_prior=pattern_prior,
num_limit=num_limit,
use_existing=False,
use_existing=False,
stack_prob=0.0)
if len(rearrange_obj_ids) == 0:
continue
Expand All @@ -269,7 +270,7 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
obj_list = selection["out_obj"]
color_list = selection["out_color"]

## Step 3: add distract objects
# Step 3: add distract objects
enable_distract = kwargs.get("enable_distract", True)
if enable_distract:
num_distract = self.max_num_obj - num_added_objs - 1
Expand All @@ -278,15 +279,16 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
else:
self.distract_obj_ids = []
num_distract = len(self.distract_obj_ids)
## Step 4:
# Step 4:
if self.max_num_pattern > 1:
if len(self.goals) < self.max_num_pattern and num_distract > 0: # not enough pattern
# 4.1 add a spatial prompt
# randomly select one from pattern obj and added obj
anchor_id = env.rng.choice(rearrange_obj_ids)
place_id = env.rng.choice(self.distract_obj_ids)
pair_obj_ids = [anchor_id, place_id]
pair_obj_names = [f"{env.obj_id_reverse_mapping[obj_id]['texture_name']} {env.obj_id_reverse_mapping[obj_id]['obj_name']}" for obj_id in pair_obj_ids]
pair_obj_names = [
f"{env.obj_id_reverse_mapping[obj_id]['texture_name']} {env.obj_id_reverse_mapping[obj_id]['obj_name']}" for obj_id in pair_obj_ids]
# compute spatial from the pair
aabb_1 = pybullet_utils.get_obj_aabb(env, pair_obj_ids[0])
aabb_2 = pybullet_utils.get_obj_aabb(env, pair_obj_ids[1])
Expand All @@ -307,9 +309,9 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
"spatial_label": spatial_label,
"spatial_str": spatial_rel
}
)
)

## Step 5: assemble prompt and goal specific
# Step 5: assemble prompt and goal specific
# gen prompt
promptor.gen_prompt()
self.prompt = promptor.prompt
Expand All @@ -318,7 +320,7 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
return self.prompt, obs

# def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSelector) -> tuple[str, dict]:
# """Generate goal config. The goal of object selector is selection a set of object,
# """Generate goal config. The goal of object selector is selection a set of object,
# with some specific pattern.
# """
# #TODO: add region prompt later, currently not supported
Expand All @@ -345,7 +347,7 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
# break
# except:
# continue

# promptor.gen_pattern_prompt(obj_str, pattern_type)
# # update goals
# pattern_info["obj_ids"] = pattern_obj_ids
Expand Down Expand Up @@ -404,4 +406,4 @@ def check_success(self, *args, **kwargs) -> ResultTuple:
return ResultTuple(success=False, failure=True, distance=None)
else:
warnings.warn(f"Pattern type {pattern_type} is not supported")
return ResultTuple(success=True, failure=False, distance=None)
return ResultTuple(success=True, failure=False, distance=None)

0 comments on commit ecf4919

Please sign in to comment.