Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Sep 12, 2024
1 parent 4f9192d commit 6c23e01
Show file tree
Hide file tree
Showing 5 changed files with 1,095 additions and 3,143 deletions.
1 change: 1 addition & 0 deletions .ci/spellcheck/.pyspelling.wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ LVLM
Lyth
macOS
Magika
masklet
Mahalanobis
Mapillary
Markovian
Expand Down
39 changes: 39 additions & 0 deletions notebooks/segment-anything/gradio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,42 @@ def get_select_coords(img, evt: gr.SelectData):
input_img.upload(on_image_change, [input_img], [input_img])

return demo


def _parse_point_coordinates(text):
    """Parse ``"x, y; x, y"`` text into a list of lists of floats.

    Groups are separated by ``;`` and the values inside a group by ``,``.
    Blank groups and blank values (for example from a trailing separator)
    are skipped instead of failing on ``float("")``.

    Raises:
        ValueError: if a value is not a number or no coordinates are given.
    """
    groups = []
    for chunk in (text or "").split(";"):
        values = [float(token) for token in chunk.split(",") if token.strip()]
        if values:
            groups.append(values)
    if not groups:
        raise ValueError("no coordinates were provided")
    return groups


def _parse_point_labels(text):
    """Parse comma-separated integer labels such as ``"1, 0"``.

    Blank entries (for example from a trailing comma) are skipped.

    Raises:
        ValueError: if an entry is not an integer or no labels are given.
    """
    parsed = [int(token) for token in (text or "").split(",") if token.strip()]
    if not parsed:
        raise ValueError("no labels were provided")
    return parsed


def make_video_demo(segmenter, sample_path):
    """Build the Gradio Blocks demo for prompt-based video segmentation.

    Args:
        segmenter: object providing ``set_video``, ``add_new_points_or_box``,
            ``propagate_in_video`` and ``save_as_video``; these are the only
            members used here.
        sample_path: directory holding the example clip ``coco.mp4``; it must
            support the ``/`` operator (presumably a ``pathlib.Path`` --
            confirm against the caller).

    Returns:
        The assembled ``gr.Blocks`` demo (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Video")
                coordinates = gr.Textbox(label="Coordinates")
                labels = gr.Textbox(label="Labels")
                submit_btn = gr.Button(value="Segment")
            with gr.Column():
                output_video = gr.Video(label="Output video")

        def on_video_change(video):
            # Hand the freshly uploaded video to the segmenter right away.
            segmenter.set_video(video)
            return video

        def segment_video(video, coordinates_txt, labels_txt):
            # Validate the prompts before touching the segmenter so a typo
            # is reported in the UI instead of surfacing as a raw ValueError.
            try:
                point_coords = _parse_point_coordinates(coordinates_txt)
                point_labels = _parse_point_labels(labels_txt)
            except ValueError as err:
                raise gr.Error(f"Invalid segmentation prompt: {err}") from err

            # set_video is repeated here because this handler is also reached
            # without going through the upload handler (e.g. from an example).
            segmenter.set_video(video)
            segmenter.add_new_points_or_box(point_coords, point_labels)
            segmenter.propagate_in_video()
            return segmenter.save_as_video()

        submit_btn.click(segment_video, inputs=[input_video, coordinates, labels], outputs=[output_video])
        input_video.upload(on_video_change, [input_video], [input_video])

        # Created inside the Blocks context so the example row is rendered;
        # the returned object itself is not needed.
        gr.Examples(examples=[[sample_path / "coco.mp4", "430, 130; 500, 100", "1, 1"]], inputs=[input_video, coordinates, labels])

    return demo
25 changes: 14 additions & 11 deletions notebooks/segment-anything/model_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __init__(
self,
ov_image_encoder,
ov_mask_encoder,
sam_predictor,
ov_memory_encoder,
ov_memory_attention_model,
memory_encoder_out_proj_weight_shape=None,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
Expand All @@ -45,17 +45,23 @@ def __init__(
self.ov_memory_encoder = ov_memory_encoder
self.ov_memory_attention_model = ov_memory_attention_model

if hasattr(sam_predictor.memory_encoder, "out_proj") and hasattr(sam_predictor.memory_encoder.out_proj, "weight"):
self.mem_dim = sam_predictor.memory_encoder.out_proj.weight.shape[0]
if not memory_encoder_out_proj_weight_shape is None:
self.mem_dim = memory_encoder_out_proj_weight_shape

# Temporal encoding of the memories
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(sam_predictor.num_maskmem, 1, 1, self.mem_dim))
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(self.num_maskmem, 1, 1, self.mem_dim))

@classmethod
def from_pretrained(
cls, predictor_video, model_info, ov_image_encoder, ov_mask_encoder, ov_memory_encoder, ov_memory_attention_model, apply_postprocessing=True
cls,
model_info,
ov_image_encoder,
ov_mask_encoder,
ov_memory_encoder,
ov_memory_attention_model,
memory_encoder_out_proj_weight_shape=None,
apply_postprocessing=True,
):
memory_attention = predictor_video.memory_attention

v_inputs = {
"sigmoid_scale_for_mem_enc": model_info["model"]["sigmoid_scale_for_mem_enc"],
Expand All @@ -80,7 +86,7 @@ def from_pretrained(
"iou_prediction_use_sigmoid": model_info["model"]["iou_prediction_use_sigmoid"],
"compile_image_encoder": False,
"image_encoder": ov_image_encoder,
"memory_attention": memory_attention,
"memory_attention": ov_memory_attention_model,
"memory_encoder": ov_memory_encoder,
}

Expand All @@ -91,11 +97,9 @@ def from_pretrained(
return cls(
ov_image_encoder=ov_image_encoder,
ov_mask_encoder=ov_mask_encoder,
sam_predictor=predictor_video,
ov_memory_encoder=ov_memory_encoder,
ov_memory_attention_model=ov_memory_attention_model,
# memory_attention=memory_attention,
# memory_encoder_out_proj_weight=out_proj_weight,
memory_encoder_out_proj_weight_shape=memory_encoder_out_proj_weight_shape,
**v_inputs,
)

Expand Down Expand Up @@ -211,7 +215,6 @@ def _prepare_memory_conditioned_features(
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
print(len(prev["maskmem_pos_enc"]), prev["maskmem_pos_enc"][-1].shape, prev["maskmem_pos_enc"][0].shape)
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
Expand Down
Loading

0 comments on commit 6c23e01

Please sign in to comment.