Skip to content

Commit

Permalink
Fixes for online execution
Browse files Browse the repository at this point in the history
  • Loading branch information
mmattamala committed Jan 27, 2024
1 parent c5ccdb3 commit b6ed9a3
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 120 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ assets/stego/cocostuff27_vit_base_5.ckpt
.pytest_cache/**
.vscode/**
notebooks/**
*.code-workspace

# DrawIO
**.dtmp
**.bkp
Expand Down Expand Up @@ -143,4 +145,5 @@ dmypy.json

# Pyre type checker
.pyre/
assets/virutal_env/*
assets/virtual_env/*

2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class AblationDataModuleParams:

@dataclass
class ModelParams:
name: str = "LinearRnvp" # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
name: str = "SimpleMLP" # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
load_ckpt: Optional[str] = None

@dataclass
Expand Down
113 changes: 87 additions & 26 deletions wild_visual_navigation/feature_extractor/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,50 @@ def __init__(
self._segmentation_type = segmentation_type
self._feature_type = feature_type
self._input_size = input_size

# Prepare segment extractor
self.segment_extractor = SegmentExtractor().to(self._device)

# Prepare extractor depending on the type
if self._feature_type == "stego":
self._feature_dim = 90
self.extractor = StegoInterface(device=device, input_size=input_size)
elif self._feature_type == "dino":
self._feature_dim = 90
self._extractor = StegoInterface(
device=device,
input_size=input_size,
n_image_clusters=kwargs.get("n_image_clusters", 20),
run_clustering=kwargs.get("run_clustering", True),
run_crf=kwargs.get("run_crf", False),
)

self.extractor = DinoInterface(
elif "dino" in self._feature_type:
self._feature_dim = 90
self._extractor = DinoInterface(
device=device,
input_size=input_size,
patch_size=kwargs.get("patch_size", 8),
backbone=kwargs.get("backbone", "dino"),
dim=kwargs.get("dino_dim", 384),
)

elif self._feature_type == "sift":
self._feature_dim = 128
self.extractor = DenseSIFTDescriptor().to(device)
self._extractor = DenseSIFTDescriptor().to(device)

elif self._feature_type == "torchvision":
self._extractor = TorchVisionInterface(
device=device, model_type=kwargs["model_type"], input_size=input_size
)

elif self._feature_type == "histogram":
self._feature_dim = 90

elif self._feature_type == "none":
pass

else:
raise f"Extractor[{self._feature_type}] not supported!"

# Segmentation
if self.segmentation_type == "slic":
from fast_slic import Slic

Expand All @@ -87,7 +101,7 @@ def extract(self, img, **kwargs):
if kwargs.get("return_dense_features", False):
return None, feat, seg, None, dense_feat

return None, feat, seg, None
return None, feat, seg, None, None

# Compute segments, their centers, and edges connecting them (graph structure)
# with Timer("feature_extractor - compute_segments"):
Expand All @@ -104,7 +118,7 @@ def extract(self, img, **kwargs):
if kwargs.get("return_dense_features", False):
return edges, feat, seg, center, dense_feat

return edges, feat, seg, center
return edges, feat, seg, center, None

@property
def feature_type(self):
Expand All @@ -125,7 +139,7 @@ def change_device(self, device):
device (str): new device
"""
self._device = device
self.extractor.change_device(device)
self._extractor.change_device(device)

def compute_segments(self, img: torch.tensor, **kwargs):
if self._segmentation_type == "none" or self._segmentation_type is None:
Expand All @@ -149,9 +163,9 @@ def compute_segments(self, img: torch.tensor, **kwargs):
# Compute edges and centers
if self._segmentation_type != "none" and self._segmentation_type is not None:
# Extract adjacency_list based on segments
edges = self.segment_extractor.adjacency_list(seg[None, None])
edges = self.segment_extractor.adjacency_list(seg)
# Extract centers
centers = self.segment_extractor.centers(seg[None, None])
centers = self.segment_extractor.centers(seg)

return edges.T, seg, centers

Expand Down Expand Up @@ -187,20 +201,20 @@ def segment_grid(self, img, **kwargs):
for i in range(patches.shape[1]):
patches[:, i, :, :, :] = i

combine_patch_size = (int(H / cell_size), int(W / cell_size))
# combine_patch_size = (int(H / cell_size), int(W / cell_size))
seg = combine_tensor_patches(
patches=patches,
original_size=(H, W),
window_size=combine_patch_size,
stride=combine_patch_size,
window_size=patch_size,
stride=patch_size,
)

return seg[0, 0].to(self._device)
return seg.to(self._device)

def segment_slic(self, img, **kwargs):
# Get slic clusters
img_np = kornia.utils.tensor_to_image(img)
seg = self.slic.iterate(np.uint8(np.ascontiguousarray(img_np) * 255))
seg = self.slic.iterate(np.uint8(np.ascontiguousarray(img_np) * 255))[None, None]
return torch.from_numpy(seg).to(self._device).type(torch.long)

def segment_random(self, img, **kwargs):
Expand All @@ -210,19 +224,19 @@ def segment_random(self, img, **kwargs):
seg = torch.full((H * W,), -1, dtype=torch.long, device=self._device)
indices = torch.randperm(H * W, device=self._device)[:nr]
seg[indices] = torch.arange(0, nr, device=self._device)
seg = seg.reshape(H, W)
seg = seg.reshape(H, W)[None, None]
return seg

def segment_stego(self, img, **kwargs):
# Prepare input image
img_internal = img.clone()
self.extractor.inference(img_internal)
seg = torch.from_numpy(self.extractor.cluster_segments).to(self._device)
self._extractor.inference(img_internal)
seg = self._extractor.cluster_segments.to(self._device)
# seg = torch.from_numpy(self._extractor.cluster_segments).to(self._device)

# Change the segment indices by numbers from 0 to N
for i, k in enumerate(seg.unique()):
seg[seg == k.item()] = i

return seg

def compute_features(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
Expand Down Expand Up @@ -254,18 +268,18 @@ def compute_histogram(self, img: torch.tensor, seg: torch.tensor, **kwargs):
def compute_sift(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
B, C, H, W = img.shape
if C == 3:
feat_r = self.extractor(img[:, 0, :, :][None])
feat_g = self.extractor(img[:, 1, :, :][None])
feat_b = self.extractor(img[:, 2, :, :][None])
feat_r = self._extractor(img[:, 0, :, :][None])
feat_g = self._extractor(img[:, 1, :, :][None])
feat_b = self._extractor(img[:, 2, :, :][None])
features = torch.cat([feat_r, feat_g, feat_b], dim=1)
else:
features = self.extractor(img)
features = self._extractor(img)
return features

@torch.no_grad()
def compute_dino(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
img_internal = img.clone()
features = self.extractor.inference(img_internal)
features = self._extractor.inference(img_internal)
return features

@torch.no_grad()
Expand All @@ -276,7 +290,7 @@ def compute_torchvision(self, img: torch.tensor, seg: torch.tensor, center: torc

@torch.no_grad()
def compute_stego(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
return self.extractor.features
return self._extractor.features

def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cumsum_trick=False):
if self._feature_type not in ["histogram"] and self._segmentation_type not in ["none"]:
Expand Down Expand Up @@ -361,9 +375,56 @@ def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cum
sparse_features = []
for i in range(seg.max() + 1):
m = seg == i
x, y = torch.where(m)
x, y = torch.where(m[0, 0])
feat = dense_features[0, :, x, y].mean(dim=1)
sparse_features.append(feat)
return torch.stack(sparse_features, dim=1).T
else:
return dense_features


def run_feature_extractor():
    """Smoke-test the feature extractor on a sample image.

    Loads ``assets/images/forest_clean.png`` (relative to ``WVN_ROOT_DIR``),
    converts it to a float32 RGB tensor of shape (1, 3, H, W) scaled to [0, 1],
    resizes/center-crops it to 448x448, and then times
    ``FeatureExtractor.extract`` for three segmentation/feature combinations:
    slic/dino, grid/dino and stego/stego.

    Raises:
        FileNotFoundError: If the sample image cannot be read from disk.
    """
    import os
    import cv2
    from os.path import join
    from pytictac import Timer
    from torchvision import transforms as T
    from wild_visual_navigation import WVN_ROOT_DIR

    # Create test directory
    os.makedirs(join(WVN_ROOT_DIR, "results", "test_feature_extractor"), exist_ok=True)

    # Run on the GPU when one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    image_path = join(WVN_ROOT_DIR, "assets/images/forest_clean.png")
    np_img = cv2.imread(image_path)
    if np_img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        raise FileNotFoundError(f"Could not read test image: {image_path}")

    # OpenCV loads BGR, HxWxC, uint8 -> convert to RGB, (1, C, H, W), float32 in [0, 1]
    np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
    img = torch.from_numpy(np_img).to(device)
    img = img.permute(2, 0, 1)
    img = (img.type(torch.float32) / 255)[None]

    transform = T.Compose(
        [
            T.Resize(448, T.InterpolationMode.NEAREST),
            T.CenterCrop(448),
        ]
    )
    img = transform(img)

    # (segmentation_type, feature_type) combinations to exercise
    configurations = (
        ("slic", "dino"),
        ("grid", "dino"),
        ("stego", "stego"),
    )
    for segmentation_type, feature_type in configurations:
        # Build the extractor outside the timer so only extract() is measured
        fe = FeatureExtractor(device=device, segmentation_type=segmentation_type, feature_type=feature_type)
        with Timer(f"{segmentation_type.upper()}-{feature_type.upper()}"):
            # Unpacking five values doubles as a check of extract()'s return arity
            edges, feat, seg, center, dense_feat = fe.extract(img)


# Script entry point: run the feature-extractor smoke test when executed directly
if __name__ == "__main__":
run_feature_extractor()
6 changes: 3 additions & 3 deletions wild_visual_navigation/feature_extractor/stego_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def inference(self, img: torch.tensor):
self._cluster_pred = F.interpolate(self._cluster_pred[None].float(), new_features_size, mode="nearest").int()
self._linear_pred = F.interpolate(self._linear_pred[None].float(), new_features_size, mode="nearest").int()

return self._linear_pred[0], self._cluster_pred[0]
return self._linear_pred, self._cluster_pred

@property
def model(self):
Expand Down Expand Up @@ -165,9 +165,9 @@ def run_stego_interfacer():

ax[0].imshow(img[0].permute(1, 2, 0).cpu().numpy())
ax[0].set_title("Image")
ax[1].imshow(si.cmap[cluster_pred[0].cpu() % si.cmap.shape[0]])
ax[1].imshow(si.cmap[cluster_pred[0, 0].cpu() % si.cmap.shape[0]])
ax[1].set_title("Cluster Predictions")
ax[2].imshow(si.cmap[linear_pred[0].cpu()])
ax[2].imshow(si.cmap[linear_pred[0, 0].cpu()])
ax[2].set_title("Linear Probe Predictions")
remove_axes(ax)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ image_graph_dist_thr: 0.2 # meters
supervision_graph_dist_thr: 0.1 # meters
network_input_image_height: 224 # 448
network_input_image_width: 224 # 448
segmentation_type: "slic"
feature_type: "dino" # TODO verify this here
segmentation_type: "stego"
feature_type: "stego" # TODO verify this here
dino_patch_size: 8 # 8 or 16; 8 is finer
slic_num_components: 100
dino_dim: 384 # 90 or 384; 384 is better
Expand All @@ -49,16 +49,14 @@ status_thread_rate: 0.5 # hertz

# Runtime options
device: "cuda"
mode: "debug" # check out comments in the class WVNMode
mode: "online" # check out comments in the class WVNMode
colormap: "RdYlBu"

print_image_callback_time: false
print_supervision_callback_time: false
log_time: false
log_confidence: false
verbose: false
debug_supervision_node_index_from_last: 10
use_debug_for_desired: false
verbose: true

extraction_store_folder: "nan"
exp: "nan"
Expand Down
Loading

0 comments on commit b6ed9a3

Please sign in to comment.