Skip to content

Commit

Permalink
update: add a struct generation for structformer
Browse files Browse the repository at this point in the history
  • Loading branch information
changhaonan committed Sep 12, 2023
1 parent ecf4919 commit 5abe90d
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 96 deletions.
47 changes: 44 additions & 3 deletions lgmcts/components/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
PATTERN_CONSTANTS = {
"line": {
"line_len": {
"L": [0.4, 0.45],
"M": [0.3, 0.4],
"S": [0.2, 0.3]
"L": [0.2, 0.2],
"M": [0.1, 0.1],
"S": [0.05, 0.05]
}
},
"circle": {
Expand Down Expand Up @@ -48,6 +48,14 @@ def gen_prior(cls, size, rng, **kwargs):
"""
raise NotImplementedError

@abstractclassmethod
def gen_ordered_prior(cls, size, rng, **kwargs):
"""Generate a fixed pattern prior:
Args:
rng: random generator
"""
raise NotImplementedError

@abstractclassmethod
def check(cls, obj_poses: dict[int, np.ndarray], **kwargs):
"""Check if the object states meet the pattern requirement
Expand Down Expand Up @@ -143,6 +151,39 @@ def gen_prior(cls, img_size, rng, **kwargs):
pattern_info["rotation"] = [0.0, 0.0, angle]
return prior, pattern_info

@classmethod
def gen_ordered_prior(cls, img_size, rng, **kwargs):
obj_id = kwargs.get("obj_id", -1)
obj_ids = kwargs.get("obj_ids", [])
thickness = kwargs.get("thickness", 1)
assert len(obj_ids) == 0 or (len(obj_ids) >= cls._num_limit[0] and len(obj_ids)
<= cls._num_limit[1]), "Number of objects should be within the limit!"

# extract relative obj & poses
obj_idx_in_list = obj_ids.index(obj_id)
assert obj_idx_in_list >= 0, "Object id not found!"
# some constants
scale = kwargs.get("scale", 0.1)

position = kwargs.get("position", [0.0, 0.0])
angle = kwargs.get("angle", 0.0)

height, width = img_size[0], img_size[1]
prior = np.zeros([height, width], dtype=np.float32)

x0 = int((position[0] + scale * math.sin(angle) * obj_idx_in_list) * width)
y0 = int((position[1] + scale * math.cos(angle) * obj_idx_in_list) * height)
cv2.circle(prior, (x0, y0), thickness, 1.0, -1)
pattern_info = {
"type": "pattern:line",
"min_length": scale,
"max_length": scale,
"length": scale,
"position": position.tolist() + [0.0],
"rotation": [0.0, 0.0, angle]
}
return prior, pattern_info

@classmethod
def check(cls, obj_poses: dict[int, np.ndarray], **kwargs):
"""Check if obj poses meets a line pattern"""
Expand Down
54 changes: 27 additions & 27 deletions lgmcts/scripts/data_generation/gen_strdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _generate_data_for_one_task(
):
# prepare path
save_path = U.f_join(save_path, task_name)
save_path = os.path.join(save_path, "circle/result")
save_path = os.path.join(save_path, "line/result")
os.makedirs(save_path, exist_ok=True)
os.makedirs(os.path.join(save_path, "batch300"), exist_ok=True)
os.makedirs(os.path.join(save_path, "index"), exist_ok=True)
Expand All @@ -62,32 +62,32 @@ def _generate_data_for_one_task(
export_file_list = []

while True:
try:
env.set_seed(seed + n_generated)
num_tried_this_seed += 1
obs_cache = []

step_t = 0
# reset
env.reset()
prompt_generator.reset()
obj_selector.reset()

# generate goal
prompt_str, obs = task.gen_goal_config(env, prompt_generator, obj_selector, enable_distract=False, force_anchor_exclude=True)
obs_cache.append(obs)

# generate start
obs = task.gen_start_config(env)
goal_spec = task.gen_goal_spec(env)
obs_cache.append(obs)

step_t += 1
except Exception as e:
print('strdiff exception:', e)
seed += 1
num_tried_this_seed = 0
continue
# try:
env.set_seed(seed + n_generated)
num_tried_this_seed += 1
obs_cache = []

step_t = 0
# reset
env.reset()
prompt_generator.reset()
obj_selector.reset()

# generate goal
prompt_str, obs = task.gen_goal_config_ordered(env, prompt_generator, obj_selector, enable_distract=False, force_anchor_exclude=True)
obs_cache.append(obs)

# generate start
obs = task.gen_start_config(env)
goal_spec = task.gen_goal_spec(env)
obs_cache.append(obs)

step_t += 1
# except Exception as e:
# print('strdiff exception:', e)
# seed += 1
# num_tried_this_seed = 0
# continue

# Process output data
obs = U.stack_sequence_fields(obs_cache)
Expand Down
143 changes: 77 additions & 66 deletions lgmcts/tasks/struct_rearrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def add_objects_to_pattern(
raise NotImplementedError("Not implemented yet")
if len(added_obj_ids) == 0:
if pattern_prior is not None:
cv2.imshow("prior", pattern_prior)
cv2.waitKey(0)
warnings.warn("No object is added to the pattern")
# assert False, "No object is added to the pattern"
return added_obj_ids, obj_status

Expand Down Expand Up @@ -204,8 +203,7 @@ def gen_type_vocabs(self):
type_vocabs["class"][obj.name] = i
type_vocabs["color"] = {}
for i, color in enumerate(self.color_list):

type_vocabs["color"][color.name] = i
type_vocabs["color"][color.name.lower().replace("_", " ")] = i
print("type_vocabs:", type_vocabs)
return type_vocabs

Expand Down Expand Up @@ -319,71 +317,84 @@ def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSe
obs, _, _, _, _ = env.step()
return self.prompt, obs

# def gen_goal_config(self, env, promptor: PromptGenerator, obj_selector: ObjectSelector) -> tuple[str, dict]:
# """Generate goal config. The goal of object selector is selection a set of object,
# with some specific pattern.
# """
# #TODO: add region prompt later, currently not supported
# max_try = 3 # max try for sampling
# obj_selector.reset()
# obj_selector.set_objs(env.obj_ids["rigid"], env.obj_list, env.color_list)

# # Step 1: select objects & colors
# num_color = 2 # Each scene only has X colors
# num_object = 6 # Each scene only has at most X objects
# selected_objs = self.rng.choice(env.obj_list, num_object, replace=False)
# color_candidates = self.rng.choice(env.color_list, num_color, replace=False)
# selected_colors = self.rng.choice(color_candidates, num_object, replace=True)
def gen_goal_config_ordered(self, env, promptor: PromptGenerator, obj_selector: ObjectSelector, **kwargs):
"""Generate goal config with fix prior"""
num_color = kwargs.get("num_color", 2) # Each scene only has X colors
force_anchor_exclude = kwargs.get("force_anchor_exclude", False)
num_added_objs = 0
obj_list = self.rng.choice(self.obj_list, min(self.max_num_obj, len(self.obj_list)), replace=False) # current candidate
color_list = self.rng.choice(self.color_list, num_color, replace=False)
# Step 1: select object candidates
for i in range(max(self.max_num_pattern - 1, 1)):
if obj_list is None or len(obj_list) <= 2:
break # no more enough candidate to formulate pattern
selected_objs = obj_list
selected_colors = self.rng.choice(color_list, len(obj_list), replace=True)
obj_selector.reset()
obj_selector.set_objs(selected_objs, selected_colors)
selection = obj_selector.gen_anchor_obj_prompt(force_anchor_exclude=force_anchor_exclude)
if not selection: # no valid selection
continue
# Step 2: select pattern & add objects to scene
if selection["anchor_obj"] is not None:
[anchor_id], _ = self.add_objects_to_pattern(
env,
objs=[selection["anchor_obj"]],
colors=[selection["anchor_color"]],
pattern_prior=None,
use_existing=False,
stack_prob=0.0) # add anchor object
else:
anchor_id = -1
# generate pattern
pattern_type = env.rng.choice(self.pattern_types)
max_try = 3
rearrange_obj_ids = []
pattern_info = {}

# ## Step 2: generate a random pattern & add objects to the scene
# pattern_type = env.rng.choice(self.pattern_types)
# # max_num_pattern = int(self.max_num_obj/2)
# max_num_pattern = 3
# for i in range(max_try):
# try:
# pattern_prior, pattern_info = PATTERN_DICT[pattern_type].gen_prior(env.ws_map_size, env.rng)
# pattern_obj_ids = self.add_objects_to_pattern(env, selected_objs, selected_colors, pattern_prior, False, self.stack_prob)
# assert len(pattern_obj_ids) > 0, "No object is added to the pattern"
# break
# except:
# continue
for i in range(max_try):
# generate random position & rotation
scale = 0.2
position = self.rng.uniform(scale, 1.0-scale, size=(2,))
angle = np.pi / 3.0 * self.rng.random() - np.pi / 6.0 # [-pi/6, pi/6]
in_objs = selection["in_obj"]
for in_id, in_obj in enumerate(in_objs):
pattern_prior, pattern_info = PATTERN_DICT[pattern_type].gen_ordered_prior(
env.ws_map_size, env.rng, obj_id=in_id, obj_ids=list(range(len(in_objs))), position=position, angle=angle, scale=scale)
added_obj_ids, obj_status = self.add_objects_to_pattern(
env,
objs=[selection["in_obj"][in_id]],
colors=[selection["in_color"][in_id]],
pattern_prior=pattern_prior,
use_existing=False,
stack_prob=0.0)
if len(added_obj_ids) == 0:
break
rearrange_obj_ids += added_obj_ids
if len(rearrange_obj_ids) == 0:
continue
else:
break

# promptor.gen_pattern_prompt(obj_str, pattern_type)
# # update goals
# pattern_info["obj_ids"] = pattern_obj_ids
# self.goals.append(pattern_info)
if anchor_id == -1:
anchor_id = rearrange_obj_ids[0]
# update goals
pattern_info["obj_ids"] = rearrange_obj_ids
pattern_info["anchor_id"] = anchor_id
self.goals.append(pattern_info)
# update prompt
promptor.gen_pattern_prompt(selection["prompt_str"], pattern_type)
# update obj
num_added_objs += len(rearrange_obj_ids)
obj_list = selection["out_obj"]
color_list = selection["out_color"]

# ## Step 2: add some more objects & spatial relationship
# # max_num_add = int(self.max_num_obj/4)
# # max_num_add = 1 #FIXME: only add one object for now
# # added_obj_ids = self.add_objects_to_random(env, max_num_add, False, self.stack_prob)
# # # randomly select one from pattern obj and added obj
# # pair_obj_ids = env.rng.choice(pattern_obj_ids + added_obj_ids, 2)
# # pair_obj_names = [f"{env.obj_id_reverse_mapping[obj_id]['texture_name']} {env.obj_id_reverse_mapping[obj_id]['obj_name']}" for obj_id in pair_obj_ids]
# # # compute spatial from the pair
# # aabb_1 = pybullet_utils.get_obj_aabb(env, pair_obj_ids[0])
# # aabb_2 = pybullet_utils.get_obj_aabb(env, pair_obj_ids[1])
# # pose_1 = spatial_utils.Points9.from_aabb(aabb_1[0], aabb_1[1])
# # pose_2 = spatial_utils.Points9.from_aabb(aabb_2[0], aabb_2[1])
# # spatial_label = spatial_utils.Points9.label(pose_1, pose_2)
# # spatial_str_list = spatial_utils.Points9.vocabulary(spatial_label)
# # if spatial_str_list[0] != "A has no relationship with B":
# # spatial_rel = self.rng.choice(spatial_str_list)
# # prompt.gen_pair_prompt(pair_obj_names[0], pair_obj_names[1], spatial_rel[4:-1].strip())
# # # update goal
# # self.goals.append(
# # {
# # "type": "pattern:spatial",
# # "obj_ids": pair_obj_ids,
# # "spatial_label": spatial_label,
# # "spatial_str": spatial_rel
# # }
# # )
# # Env step forward
# obs, _, _, _, _ = env.step()
# #
# self.prompt = promptor.prompt
# return self.prompt, obs
# gen prompt
promptor.gen_prompt()
self.prompt = promptor.prompt
# Env step forward
obs, _, _, _, _ = env.step()
return self.prompt, obs

def gen_start_config(self, env) -> dict:
"""Generate a random config using existing objects"""
Expand Down

0 comments on commit 5abe90d

Please sign in to comment.