From afb5effc14a43d8cdac7fe38e4b75a7f7eb06e84 Mon Sep 17 00:00:00 2001 From: Matias Mattamala Date: Wed, 31 Jan 2024 22:04:20 +0000 Subject: [PATCH] Fix unit tests --- tests/test_confidence_generator.py | 15 +- tests/test_dino_time.py | 32 +- tests/test_feature_extractor.py | 82 +++-- tests/test_image_projector.py | 34 +- tests/test_kornia.py | 7 +- tests/test_monitoring.py | 18 +- tests/test_optical_flow.py | 31 +- .../cfg/experiment_params.py | 4 +- wild_visual_navigation/cfg/ros_params.py | 7 +- .../feature_extractor/dino_interface.py | 22 +- .../feature_extractor/dino_trt_interface.py | 16 +- .../feature_extractor/feature_extractor.py | 68 +--- .../feature_extractor/stego_interface.py | 29 +- .../torchvision_interface.py | 24 +- .../image_projector/image_projector.py | 24 +- wild_visual_navigation/model/linear_rnvp.py | 32 +- wild_visual_navigation/utils/__init__.py | 1 + .../utils/confidence_generator.py | 31 +- .../wild_visual_navigation/default.yaml | 8 +- .../scripts/overlay_images.py | 11 +- .../scripts/wvn_feature_extractor_node.py | 302 ++++++++++-------- .../scripts/wvn_learning_node.py | 71 ++-- 22 files changed, 411 insertions(+), 458 deletions(-) diff --git a/tests/test_confidence_generator.py b/tests/test_confidence_generator.py index beecd2da..e3c4a0bf 100644 --- a/tests/test_confidence_generator.py +++ b/tests/test_confidence_generator.py @@ -20,6 +20,13 @@ def generate_traversability_test_signal(N=500, T=10, events=[3, 8], event_length def test_confidence_generator(): + from wild_visual_navigation.utils.testing import make_results_folder + from wild_visual_navigation.visu import get_img_from_fig + from os.path import join + + # Create test directory + outpath = make_results_folder("test_confidence_generator") + device = torch.device("cpu") N = 1000 sigma_factor = 0.5 @@ -149,7 +156,13 @@ def test_confidence_generator(): axs[2].set_ylabel("Confidence") axs[2].legend(loc="upper right") - plt.show() + img = get_img_from_fig(fig) + img.save( + join( + outpath, + "confidence_generator_test.png", + ) + ) if __name__ == "__main__": diff --git a/tests/test_dino_time.py b/tests/test_dino_time.py index 26fa6743..ce6b34a0 100644 --- a/tests/test_dino_time.py +++ b/tests/test_dino_time.py @@ -1,16 +1,15 @@ -from wild_visual_navigation import WVN_ROOT_DIR +# from wild_visual_navigation import WVN_ROOT_DIR from pytictac import Timer from wild_visual_navigation.feature_extractor import ( DinoInterface, - DinoTrtInterface, # TrtModel, ) # from collections import namedtuple, OrderedDict # from torchvision import transforms as T -import cv2 -import os +# import os import torch +from wild_visual_navigation.utils.testing import load_test_image, get_dino_transform # import tensorrt as trt # import numpy as np @@ -19,23 +18,21 @@ def test_dino_interfacer(): device = "cuda" if torch.cuda.is_available() else "cpu" di = DinoInterface(device) + transform = get_dino_transform() - np_img = cv2.imread(os.path.join(WVN_ROOT_DIR, "assets/images/forest_clean.png")) - img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] + img = load_test_image().to(device) ##################################################################################### for i in range(5): im = img + torch.rand(img.shape, device=img.device) / 100 - di.inference(di.transform(im)) + di.inference(transform(im)) ##################################################################################### with Timer("BS1 Dino Inference: "): for i in range(5): im = img + torch.rand(img.shape, device=img.device) / 100 with Timer("BS1 Dino Single: "): - di.inference(di.transform(im)) + di.inference(transform(im)) ##################################################################################### # img = img.repeat(4, 1, 1, 1) @@ -46,14 +43,15 @@ def test_dino_interfacer(): # res = di.inference(di.transform(im)) ##################################################################################### - # Conversion from ONNX model (https://github.com/facebookresearch/dino) - exported_trt_file = "dino_exported.trt" - exported_trt_path = os.path.join(WVN_ROOT_DIR, "assets/dino", exported_trt_file) - di_trt = DinoTrtInterface(exported_trt_path, device) + # # Conversion from ONNX model (https://github.com/facebookresearch/dino) + # from wild_visual_navigation.feature_extractor import DinoTrtInterface + # exported_trt_file = "dino_exported.trt" + # exported_trt_path = os.path.join(WVN_ROOT_DIR, "assets/dino", exported_trt_file) + # di_trt = DinoTrtInterface(exported_trt_path, device) - with Timer("TensorRT Inference: "): - im = img + torch.rand(img.shape, device=img.device) / 100 - di_trt.inference(di.transform(im).contiguous()) + # with Timer("TensorRT Inference: "): + # im = img + torch.rand(img.shape, device=img.device) / 100 + # di_trt.inference(di.transform(im).contiguous()) ##################################################################################### # Conversion using the torch_tensorrt library: https://github.com/pytorch/TensorRT diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py index 2266256a..37b7f6e9 100644 --- a/tests/test_feature_extractor.py +++ b/tests/test_feature_extractor.py @@ -1,53 +1,51 @@ -from wild_visual_navigation import WVN_ROOT_DIR from wild_visual_navigation.feature_extractor import FeatureExtractor from wild_visual_navigation.visu import get_img_from_fig +from wild_visual_navigation.utils.testing import load_test_image, get_dino_transform, make_results_folder +from os.path import join +from pytictac import Timer import matplotlib.pyplot as plt import torch -from torchvision import transforms as T -from pathlib import PurePath, Path -import os -import cv2 +import itertools def test_feature_extractor(): - segmentation_type = "none" - feature_type = "dino" - device = "cuda" if torch.cuda.is_available() else "cpu" - fe = FeatureExtractor(device, segmentation_type=segmentation_type, feature_type=feature_type) - - transform = T.Compose( - [ - T.Resize(448, T.InterpolationMode.NEAREST), - T.CenterCrop(448), - ] - ) - - np_img = cv2.imread(os.path.join(WVN_ROOT_DIR, "assets/images/forest_clean.png")) - img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] - adj, feat, seg, center = fe.extract(transform(img.clone())) - - p = PurePath(WVN_ROOT_DIR).joinpath( - "results", - "test_feature_extractor", - f"forest_clean_graph_{segmentation_type}.png", - ) - Path(p.parent).mkdir(parents=True, exist_ok=True) - - # Plot result as in colab - fig, ax = plt.subplots(1, 2, figsize=(5 * 3, 5)) - - ax[0].imshow(transform(img).permute(0, 2, 3, 1)[0].cpu()) - ax[0].set_title("Image") - ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("inferno")) - ax[1].set_title("Segmentation") - plt.tight_layout() - - # Store results to test directory - img = get_img_from_fig(fig) - img.save(str(p)) + segmentation_types = ["none", "grid", "slic", "random", "stego"] + feature_types = ["dino", "dinov2", "stego"] + backbone_types = ["vit_small", "vit_base", "vit_small_reg", "vit_base_reg"] + + for seg_type, feat_type, back_type in itertools.product(segmentation_types, feature_types, backbone_types): + if seg_type == "stego" and feat_type != "stego": + continue + + with Timer(f"Running seg [{seg_type}], feat [{feat_type}], backbone [{back_type}]"): + try: + fe = FeatureExtractor( + device, segmentation_type=seg_type, feature_type=feat_type, backbone_type=back_type + ) + except Exception: + print("Not available") + continue + + img = load_test_image().to(device) + transform = get_dino_transform() + outpath = make_results_folder("test_feature_extractor") + + # Compute + edges, feat, seg, center, dense_feat = fe.extract(transform(img.clone())) + + # Plot result as in colab + fig, ax = plt.subplots(1, 2, figsize=(5 * 3, 5)) + + ax[0].imshow(transform(img).permute(0, 2, 3, 1)[0].cpu()) + ax[0].set_title("Image") + ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("inferno")) + ax[1].set_title("Segmentation") + plt.tight_layout() + + # Store results to test directory + img = get_img_from_fig(fig) + img.save(join(outpath, f"forest_clean_graph_{seg_type}_{feat_type}.png")) if __name__ == "__main__": diff --git a/tests/test_image_projector.py b/tests/test_image_projector.py index cf2eb410..61eea3c7 100644 --- a/tests/test_image_projector.py +++ b/tests/test_image_projector.py @@ -7,24 +7,19 @@ def test_image_projector(): def test_supervision_projection(): - from wild_visual_navigation import WVN_ROOT_DIR from wild_visual_navigation.image_projector import ImageProjector from wild_visual_navigation.visu import get_img_from_fig - from PIL import Image + from wild_visual_navigation.utils.testing import load_test_image, make_results_folder from liegroups.torch import SE3, SO3 import matplotlib.pyplot as plt import torch - import torchvision.transforms as transforms - import os from os.path import join from kornia.utils import tensor_to_image - from stego.src import remove_axes + from stego.utils import remove_axes import random - to_tensor = transforms.ToTensor() - # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_image_projector"), exist_ok=True) + outpath = make_results_folder("test_image_projector") # Define number of cameras (batch) B = 100 @@ -36,19 +31,17 @@ def test_supervision_projection(): # Extrisics pose_camera_in_world = torch.eye(4)[None] - # Image size - H = 1080 - W = 1440 + H = torch.tensor(1080) + W = torch.tensor(1440) # Create projector im = ImageProjector(K, H, W) # Load image - pil_image = Image.open(join(WVN_ROOT_DIR, "assets/images/forest_clean.png")) + torch_image = load_test_image() - # Convert to torch - torch_image = to_tensor(pil_image) + # Resize torch_image = im.resize_image(torch_image) mask = (torch_image * 0.0)[None] @@ -68,7 +61,7 @@ def test_supervision_projection(): delta = SE3(R_WC, rho).as_matrix()[None] # Pose matrix of camera in world frame pose_base_in_world = pose_base_in_world @ delta pose_footprint_in_base = torch.eye(4)[None] - print(delta, pose_base_in_world) + # print(delta, pose_base_in_world) twist = torch.rand((3,)) supervision = torch.rand((10,)) @@ -105,7 +98,7 @@ def test_supervision_projection(): fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 5)) ax[0].imshow(tensor_to_image(torch_image)) ax[0].set_title("Image") - ax[1].imshow(tensor_to_image(mask)) + ax[1].imshow(tensor_to_image(mask[0])) ax[1].set_title("Labels") remove_axes(ax) @@ -113,14 +106,7 @@ def test_supervision_projection(): # Store results to test directory img = get_img_from_fig(fig) - img.save( - join( - WVN_ROOT_DIR, - "results", - "test_image_projector", - "forest_clean_supervision_projection.png", - ) - ) + img.save(join(outpath, "forest_clean_supervision_projection.png")) if __name__ == "__main__": diff --git a/tests/test_kornia.py b/tests/test_kornia.py index c44cedfe..cd599ac4 100644 --- a/tests/test_kornia.py +++ b/tests/test_kornia.py @@ -1,6 +1,5 @@ -from wild_visual_navigation import WVN_ROOT_DIR from wild_visual_navigation.visu import get_img_from_fig -import os +from wild_visual_navigation.utils.testing import make_results_folder from os.path import join import matplotlib.pyplot as plt import torch @@ -47,8 +46,8 @@ def test_draw_polygon(): # Draw k_out = draw_convex_polygon(img, poly, color) # Show - os.makedirs(join(WVN_ROOT_DIR, "results", "test_kornia"), exist_ok=True) + outpath = make_results_folder("test_kornia") fig, ax = plt.subplots(1, 1, figsize=(5, 5)) ax.imshow(tensor_to_image(k_out)) out_img = get_img_from_fig(fig) - out_img.save(join(WVN_ROOT_DIR, "results", "test_kornia", "polygon_test.png")) + out_img.save(join(outpath, "polygon_test.png")) diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index 7829bd10..91c36fe2 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -1,5 +1,5 @@ from wild_visual_navigation.utils import SystemLevelGpuMonitor, accumulate_memory -from pytictac import SystemLevelTimer, accumulate_time +from pytictac import accumulate_time import time import torch @@ -15,13 +15,13 @@ def __init__(self): @accumulate_time def test_memory_then_timing(self, s): time.sleep(s / 1000) - self.tensors.append(torch.zeros((s, s), device="cuda")) + self.tensors.append(torch.zeros((10 * s, 10 * s), device="cuda")) @accumulate_time @accumulate_memory def test_timing_then_memory(self, s): time.sleep(s / 1000) - self.tensors.append(torch.zeros((4, s, s), device="cuda")) + self.tensors.append(torch.zeros((4, 10 * s, 10 * s), device="cuda")) # Create objects my_test = MyTest() @@ -33,13 +33,13 @@ def test_timing_then_memory(self, s): store_samples=True, skip_n_samples=1, ) - time_monitor = SystemLevelTimer( - objects=[my_test], - names=["test"], - ) + # time_monitor = ClassContextTimer( + # objects=[my_test], + # names=["test"], + # ) # Run loop - for n in range(400): + for n in range(100): print(f"step {n}") step = n t = n / 10 @@ -48,7 +48,7 @@ def test_timing_then_memory(self, s): my_test.test_timing_then_memory(n) gpu_monitor.store("/tmp") - time_monitor.store("/tmp") + # time_monitor.store("/tmp") if __name__ == "__main__": diff --git a/tests/test_optical_flow.py b/tests/test_optical_flow.py index f81a84ff..065b5b0b 100644 --- a/tests/test_optical_flow.py +++ b/tests/test_optical_flow.py @@ -1,15 +1,22 @@ -from pytorch_pwc.network import PwcFlowEstimator -from pytorch_pwc import PWC_ROOT_DIR -import torch -from wild_visual_navigation import WVN_ROOT_DIR -from wild_visual_navigation.visu import LearningVisualizer -import numpy as np -import PIL -import os -from pytictac import Timer +import pytest +import sys -if __name__ == "__main__": +try: + from pytorch_pwc.network import PwcFlowEstimator + from pytorch_pwc import PWC_ROOT_DIR +except ImportError: + pass + + +@pytest.mark.skipif("pytorch_pwc" not in sys.modules, reason="requires the pytorch_pwc library") +def pytorch_pwc_test(): import os + import torch + from wild_visual_navigation import WVN_ROOT_DIR + from wild_visual_navigation.visu import LearningVisualizer + import numpy as np + import PIL + from pytictac import Timer tenOne = torch.FloatTensor( np.ascontiguousarray( @@ -39,3 +46,7 @@ visu = LearningVisualizer(p_visu=os.path.join(WVN_ROOT_DIR, "results/test_visu"), store=True) visu.plot_optical_flow(res, tenOne, tenTwo) print("done") + + +if __name__ == "__main__": + pytorch_pwc_test() diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py index 63412abf..dc3c62bc 100644 --- a/wild_visual_navigation/cfg/experiment_params.py +++ b/wild_visual_navigation/cfg/experiment_params.py @@ -102,7 +102,7 @@ class ModelParams: @dataclass class SimpleMlpCfgParams: - input_size: int = 384 + input_size: int = 384 # 90 for stego, 384 for dino hidden_sizes: List[int] = field(default_factory=lambda: [256, 32, 1]) reconstruction: bool = True @@ -125,7 +125,7 @@ class SimpleGcnCfgParams: @dataclass class LinearRnvpCfgParams: - input_dim: int = 384 + input_size: int = 384 coupling_topology: List[int] = field(default_factory=lambda: [200]) mask_type: str = "odds" conditioning_size: int = 0 diff --git a/wild_visual_navigation/cfg/ros_params.py b/wild_visual_navigation/cfg/ros_params.py index adc28703..5b9b2578 100644 --- a/wild_visual_navigation/cfg/ros_params.py +++ b/wild_visual_navigation/cfg/ros_params.py @@ -29,8 +29,8 @@ class RosLearningNodeParams: segmentation_type: str feature_type: str dino_patch_size: int # 8 or 16; 8 is finer + dino_backbone: str # vit_small, vit_base slic_num_components: int - dino_dim: int # 90 or 384; 384 is better confidence_std_factor: float scale_traversability: bool # This parameter needs to be false when using the anomaly detection model scale_traversability_max_fpr: float @@ -76,8 +76,8 @@ class RosFeatureExtractorNodeParams: segmentation_type: str feature_type: str dino_patch_size: int # 8 or 16; 8 is finer + dino_backbone: str # vit_small, vit_base slic_num_components: int - dino_dim: int # 90 or 384; 384 is better # ConfidenceGenerator confidence_std_factor: float @@ -94,3 +94,6 @@ class RosFeatureExtractorNodeParams: device: str log_confidence: bool verbose: bool + + # Threads + image_callback_rate: float # hertz diff --git a/wild_visual_navigation/feature_extractor/dino_interface.py b/wild_visual_navigation/feature_extractor/dino_interface.py index 1df83fdc..61e94f9a 100644 --- a/wild_visual_navigation/feature_extractor/dino_interface.py +++ b/wild_visual_navigation/feature_extractor/dino_interface.py @@ -1,5 +1,3 @@ -from wild_visual_navigation import WVN_ROOT_DIR -import os from os.path import join import torch.nn.functional as F import torch @@ -17,7 +15,6 @@ def __init__( input_size: int = 448, backbone_type: str = "vit_small", patch_size: int = 8, - dim: int = 384, projection_type: str = None, # nonlinear or None dropout_p: float = 0, # True or False pretrained_weights: str = None, @@ -31,7 +28,6 @@ def __init__( "backbone_type": backbone_type, "input_size": input_size, "patch_size": patch_size, - "dim": dim, "projection_type": projection_type, "dropout_p": dropout_p, "pretrained_weights": pretrained_weights, @@ -112,20 +108,16 @@ def run_dino_interfacer(): from pytictac import Timer from wild_visual_navigation.visu import get_img_from_fig + from wild_visual_navigation.utils.testing import load_test_image, make_results_folder import matplotlib.pyplot as plt from stego.utils import remove_axes - import cv2 # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_dino_interfacer"), exist_ok=True) + outpath = make_results_folder("test_dino_interfacer") # Inference model device = "cuda" if torch.cuda.is_available() else "cpu" - p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") - np_img = cv2.imread(p) - img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] + img = load_test_image().to(device) img = F.interpolate(img, scale_factor=0.25) plot = False @@ -163,9 +155,7 @@ def run_dino_interfacer(): out_img = get_img_from_fig(fig) out_img.save( join( - WVN_ROOT_DIR, - "results", - "test_dino_interfacer", + outpath, f"forest_clean_dino_feat{i:02}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png", ) ) @@ -196,9 +186,7 @@ def run_dino_interfacer(): out_img = get_img_from_fig(fig) out_img.save( join( - WVN_ROOT_DIR, - "results", - "test_dino_interfacer", + outpath, f"forest_clean_{di.backbone}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png", ) ) diff --git a/wild_visual_navigation/feature_extractor/dino_trt_interface.py b/wild_visual_navigation/feature_extractor/dino_trt_interface.py index 006b6c0e..749b0f60 100644 --- a/wild_visual_navigation/feature_extractor/dino_trt_interface.py +++ b/wild_visual_navigation/feature_extractor/dino_trt_interface.py @@ -124,24 +124,22 @@ def run_dino_trt_interfacer(): """Performance inference using stego and stores result as an image.""" from wild_visual_navigation.visu import get_img_from_fig + from wild_visual_navigation.testing import load_test_image, get_dino_transform + from wild_visual_navigation.utils.testing import make_results_folder import matplotlib.pyplot as plt from stego.src import remove_axes - import cv2 # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_dino_trt_interfacer"), exist_ok=True) + outpath = make_results_folder("test_dino_trt_interfacer") # Inference model device = "cuda" if torch.cuda.is_available() else "cpu" di = DinoTrtInterface(device=device) - p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") - np_img = cv2.imread(p) - img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] + img = load_test_image().to(device) + transform = get_dino_transform() # Inference with DINO - feat_dino = di.inference(di.transform(img), interpolate=False) + feat_dino = di.inference(transform(img), interpolate=False) # Fix size of DINO features to match input image's size B, D, H, W = img.shape @@ -171,7 +169,7 @@ def run_dino_trt_interfacer(): # Store results to test directory img = get_img_from_fig(fig) - img.save(join(WVN_ROOT_DIR, "results", "test_dino_trt_interfacer", "forest_clean_dino.png")) + img.save(join(outpath, "forest_clean_dino.png")) if __name__ == "__main__": diff --git a/wild_visual_navigation/feature_extractor/feature_extractor.py b/wild_visual_navigation/feature_extractor/feature_extractor.py index 055dd758..38ba8267 100644 --- a/wild_visual_navigation/feature_extractor/feature_extractor.py +++ b/wild_visual_navigation/feature_extractor/feature_extractor.py @@ -47,13 +47,13 @@ def __init__( ) elif "dino" in self._feature_type: - self._feature_dim = 90 + self._feature_dim = 384 self._extractor = DinoInterface( device=device, input_size=input_size, patch_size=kwargs.get("patch_size", 8), - backbone=kwargs.get("backbone", "dino"), - dim=kwargs.get("dino_dim", 384), + backbone=kwargs.get("backbone", self._feature_type), + backbone_type=kwargs.get("backbone_type", "vit_base"), ) elif self._feature_type == "sift": @@ -92,6 +92,7 @@ def extract(self, img, **kwargs): H, W = img.shape[2:] nr = kwargs.get("n_random_pixels", 100) + seg = torch.full((H * W,), -1, dtype=torch.long, device=self._device) indices = torch.randperm(H * W, device=self._device)[:nr] seg[indices] = torch.arange(0, nr, device=self._device) @@ -167,7 +168,7 @@ def compute_segments(self, img: torch.tensor, **kwargs): # Extract centers centers = self.segment_extractor.centers(seg) - return edges.T, seg, centers + return edges.T, seg[0, 0], centers def segment_pixelwise(self, img, **kwargs): # Generate pixel-wise segmentation @@ -186,7 +187,7 @@ def segment_pixelwise(self, img, **kwargs): ver_edges = torch.cat((seg[:-1, :].reshape(-1, 1), seg[1:, :].reshape(-1, 1)), dim=1) edges = torch.cat((hor_edges, ver_edges), dim=0) - return edges, seg, centers + return edges, seg[None, None], centers def segment_grid(self, img, **kwargs): cell_size = kwargs.get("cell_size", 32) @@ -246,7 +247,7 @@ def compute_features(self, img: torch.tensor, seg: torch.tensor, center: torch.t elif self._feature_type == "sift": feat = self.compute_sift(img, seg, center, **kwargs) - elif self._feature_type == "dino": + elif "dino" in self._feature_type: feat = self.compute_dino(img, seg, center, **kwargs) elif self._feature_type == "stego": @@ -290,7 +291,11 @@ def compute_torchvision(self, img: torch.tensor, seg: torch.tensor, center: torc @torch.no_grad() def compute_stego(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs): - return self._extractor.features + try: + return self._extractor.features + except Exception: + self.segment_stego(img, **kwargs) + return self._extractor.features def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cumsum_trick=False): if self._feature_type not in ["histogram"] and self._segmentation_type not in ["none"]: @@ -375,56 +380,9 @@ def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cum sparse_features = [] for i in range(seg.max() + 1): m = seg == i - x, y = torch.where(m[0, 0]) + x, y = torch.where(m) feat = dense_features[0, :, x, y].mean(dim=1) sparse_features.append(feat) return torch.stack(sparse_features, dim=1).T else: return dense_features - - -def run_feature_extractor(): - """Tests feature extractor""" - import os - import cv2 - from os.path import join - from pytictac import Timer - from torchvision import transforms as T - from wild_visual_navigation import WVN_ROOT_DIR - - # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_feature_extractor"), exist_ok=True) - - # Inference model - device = "cuda" if torch.cuda.is_available() else "cpu" - - p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") - np_img = cv2.imread(p) - np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(np_img).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] - transform = T.Compose( - [ - T.Resize(448, T.InterpolationMode.NEAREST), - T.CenterCrop(448), - ] - ) - img = transform(img) - - # create feature extractor - fe = FeatureExtractor(device=device, segmentation_type="slic", feature_type="dino") - with Timer(f"SLIC-DINO"): - edges, feat, seg, center, dense_feat = fe.extract(img) - - fe = FeatureExtractor(device=device, segmentation_type="grid", feature_type="dino") - with Timer(f"GRID-DINO"): - edges, feat, seg, center, dense_feat = fe.extract(img) - - fe = FeatureExtractor(device=device, segmentation_type="stego", feature_type="stego") - with Timer(f"STEGO-STEGO"): - edges, feat, seg, center, dense_feat = fe.extract(img) - - -if __name__ == "__main__": - run_feature_extractor() diff --git a/wild_visual_navigation/feature_extractor/stego_interface.py b/wild_visual_navigation/feature_extractor/stego_interface.py index 83a5ecdc..d1a52098 100644 --- a/wild_visual_navigation/feature_extractor/stego_interface.py +++ b/wild_visual_navigation/feature_extractor/stego_interface.py @@ -1,11 +1,10 @@ -from wild_visual_navigation import WVN_ROOT_DIR -import os from os.path import join import torch.nn.functional as F import torch from torchvision import transforms as T from omegaconf import OmegaConf +from pytictac import Timer from stego import STEGO_ROOT_DIR from stego.stego import Stego from stego.data import create_cityscapes_colormap @@ -17,7 +16,7 @@ def __init__( device: str, input_size: int = 448, model_path: str = f"{STEGO_ROOT_DIR}/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt", - n_image_clusters: int = 20, + n_image_clusters: int = 40, run_crf: bool = True, run_clustering: bool = False, cfg: OmegaConf = OmegaConf.create({}), @@ -79,10 +78,14 @@ def inference(self, img: torch.tensor): # assert 1 == img.shape[0] # Resize image and normalize + # with Timer("input normalization"): resized_img = self._transform(img).to(self._device) # Run STEGO + # with Timer("compute code"): self._code = self._model.get_code(resized_img) + + # with Timer("compute postprocess"): self._cluster_pred, self._linear_pred = self._model.postprocess( code=self._code, img=resized_img, @@ -92,6 +95,7 @@ def inference(self, img: torch.tensor): ) # resize and interpolate features + # with Timer("interpolate output"): B, D, H, W = img.shape new_features_size = (H, H) # pad = int((W - H) / 2) @@ -128,15 +132,13 @@ def features(self): def run_stego_interfacer(): """Performance inference using stego and stores result as an image.""" - - from pytictac import Timer from wild_visual_navigation.visu import get_img_from_fig + from wild_visual_navigation.utils.testing import load_test_image, make_results_folder from stego.utils import remove_axes import matplotlib.pyplot as plt - import cv2 # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_stego_interfacer"), exist_ok=True) + outpath = make_results_folder("test_stego_interfacer") # Inference model device = "cuda" if torch.cuda.is_available() else "cpu" @@ -149,15 +151,10 @@ def run_stego_interfacer(): n_image_clusters=20, ) - p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") - np_img = cv2.imread(p) - np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(np_img).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] + img = load_test_image().to(device) img = F.interpolate(img, scale_factor=0.5) - with Timer(f"Stego input {si.input_size}"): + with Timer(f"Stego input {si.input_size}\n"): linear_pred, cluster_pred = si.inference(img) # Plot result as in colab @@ -175,9 +172,7 @@ def run_stego_interfacer(): img = get_img_from_fig(fig) img.save( join( - WVN_ROOT_DIR, - "results", - "test_stego_interfacer", + outpath, f"forest_clean_stego_{si.input_size}.png", ) ) diff --git a/wild_visual_navigation/feature_extractor/torchvision_interface.py b/wild_visual_navigation/feature_extractor/torchvision_interface.py index ea2ff442..75961623 100644 --- a/wild_visual_navigation/feature_extractor/torchvision_interface.py +++ b/wild_visual_navigation/feature_extractor/torchvision_interface.py @@ -1,9 +1,7 @@ from wild_visual_navigation import WVN_ROOT_DIR -import os from os.path import join from torch import nn -import torch.nn.functional as F import torch import torchvision.models as models @@ -120,31 +118,19 @@ def inference(self, img: torch.tensor, interpolate: bool = False): def run_torch_vision_model_interfacer(): """Performance inference using stego and stores result as an image.""" - - from pytictac import Timer - from wild_visual_navigation.visu import get_img_from_fig - import matplotlib.pyplot as plt - from stego.src import remove_axes - import cv2 - - # Create test directory - # os.makedirs(join(WVN_ROOT_DIR, "results", "test_torchvision_interfacer"), exist_ok=True) + from wild_visual_navigation.utils.testing import load_test_image # Inference model device = "cuda" if torch.cuda.is_available() else "cpu" - p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") - np_img = cv2.imread(p) - img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)).to(device) - img = img.permute(2, 0, 1) - img = (img.type(torch.float32) / 255)[None] + img = load_test_image().to(device) - plot = False - save_features = True + # plot = False + # save_features = True size = 448 model_type = "resnet18" - di = TorchVisionInterface(model_type=model_type, input_size=488) + di = TorchVisionInterface(model_type=model_type, input_size=size) di.to(device) img.to(device) res = di(img) diff --git a/wild_visual_navigation/image_projector/image_projector.py b/wild_visual_navigation/image_projector/image_projector.py index 7f3a1a8b..e077de89 100644 --- a/wild_visual_navigation/image_projector/image_projector.py +++ b/wild_visual_navigation/image_projector/image_projector.py @@ -1,6 +1,4 @@ -from wild_visual_navigation import WVN_ROOT_DIR from pytictac import Timer -import os from os.path import join import torch from torchvision import transforms as T @@ -204,17 +202,14 @@ def run_image_projector(): from wild_visual_navigation.utils import ( make_polygon_from_points, ) - from PIL import Image + from wild_visual_navigation.utils.testing import load_test_image, make_results_folder import matplotlib.pyplot as plt import torch - import torchvision.transforms as transforms from kornia.utils import tensor_to_image - from stego.src import remove_axes - - to_tensor = transforms.ToTensor() + from stego.utils import remove_axes # Create test directory - os.makedirs(join(WVN_ROOT_DIR, "results", "test_image_projector"), exist_ok=True) + outpath = make_results_folder("test_image_projector") # Define number of cameras (batch) B = 10 @@ -234,17 +229,14 @@ def run_image_projector(): R_WC = SO3.from_rpy(phi) # Rotation matrix from roll-pitch-yaw pose_camera_in_world[i] = SE3(R_WC, rho).as_matrix() # Pose matrix of camera in world frame # Image size - H = 1080 - W = 1440 + H = torch.tensor(1080) + W = torch.tensor(1440) # Create projector im = ImageProjector(K, H, W) # Load image - pil_img = Image.open(join(WVN_ROOT_DIR, "assets/images/forest_clean.png")) - - # Convert to torch - k_img = to_tensor(pil_img) + k_img = load_test_image() k_img = k_img.expand(B, 3, H, W) k_img = im.resize_image(k_img) @@ -292,9 +284,7 @@ def run_image_projector(): img = get_img_from_fig(fig) img.save( join( - WVN_ROOT_DIR, - "results", - "test_image_projector", + outpath, "forest_clean_image_projector.png", ) ) diff --git a/wild_visual_navigation/model/linear_rnvp.py b/wild_visual_navigation/model/linear_rnvp.py index 7941e219..d91eac47 100644 --- a/wild_visual_navigation/model/linear_rnvp.py +++ b/wild_visual_navigation/model/linear_rnvp.py @@ -70,7 +70,7 @@ class LinearCouplingLayer(nn.Module): def __init__( self, - input_dim, + input_size, mask, network_topology, conditioning_size=None, @@ -82,14 +82,14 @@ def __init__( conditioning_size = 0 if network_topology is None or len(network_topology) == 0: - network_topology = [input_dim] + network_topology = [input_size] self.register_buffer("mask", mask) - self.dim = input_dim + self.dim = input_size self.s = [ - nn.Linear(input_dim + conditioning_size, network_topology[0]), + nn.Linear(input_size + conditioning_size, network_topology[0]), nn.ReLU(), ] @@ -99,9 +99,9 @@ def __init__( self.s.extend([nn.Linear(t_p, t), nn.ReLU()]) if single_function: - input_dim = input_dim * 2 + input_size = input_size * 2 - ll = nn.Linear(network_topology[-1], input_dim) + ll = nn.Linear(network_topology[-1], input_size) self.s.append(ll) self.s = nn.Sequential(*self.s) @@ -216,7 +216,7 @@ class LinearRnvp(nn.Module): def __init__( self, - input_dim, + input_size, coupling_topology, flow_n=2, use_permutation=False, @@ -228,26 +228,26 @@ def __init__( ): super().__init__() - self.register_buffer("prior_mean", torch.zeros(input_dim)) # Normal Gaussian with zero mean - self.register_buffer("prior_var", torch.ones(input_dim)) # Normal Gaussian with unit variance + self.register_buffer("prior_mean", torch.zeros(input_size)) # Normal Gaussian with zero mean + self.register_buffer("prior_var", torch.ones(input_size)) # Normal Gaussian with unit variance if mask_type == "odds": - mask = torch.arange(0, input_dim).float() % 2 + mask = torch.arange(0, input_size).float() % 2 elif mask_type == "half": - mask = torch.zeros(input_dim) - mask[: input_dim // 2] = 1 + mask = torch.zeros(input_size) + mask[: input_size // 2] = 1 else: assert False if coupling_topology is None: - coupling_topology = [input_dim // 2, input_dim // 2] + coupling_topology = [input_size // 2, input_size // 2] blocks = [] for i in range(flow_n): blocks.append( LinearCouplingLayer( - input_dim, + input_size, mask, network_topology=coupling_topology, conditioning_size=conditioning_size, @@ -255,12 +255,12 @@ def __init__( ) ) if use_permutation: - blocks.append(Permutation(input_dim)) + blocks.append(Permutation(input_size)) else: mask = 1 - mask if batch_norm: - blocks.append(LinearBatchNorm(input_dim)) + blocks.append(LinearBatchNorm(input_size)) self.flows = SequentialFlow(*blocks) diff --git a/wild_visual_navigation/utils/__init__.py b/wild_visual_navigation/utils/__init__.py index 764f14c2..35d04975 100644 --- a/wild_visual_navigation/utils/__init__.py +++ b/wild_visual_navigation/utils/__init__.py @@ -24,3 +24,4 @@ accumulate_memory, ) from .loss import TraversabilityLoss, AnomalyLoss +from .testing import load_test_image, get_dino_transform, make_results_folder diff --git a/wild_visual_navigation/utils/confidence_generator.py b/wild_visual_navigation/utils/confidence_generator.py index 2be2e77f..2a013e7b 100644 --- a/wild_visual_navigation/utils/confidence_generator.py +++ b/wild_visual_navigation/utils/confidence_generator.py @@ -1,4 +1,5 @@ from wild_visual_navigation.utils import KalmanFilter +from wild_visual_navigation import WVN_ROOT_DIR import torch import os from collections import deque @@ -10,14 +11,13 @@ def __init__( std_factor: float = 0.7, method: str = "running_mean", log_enabled: bool = False, - log_folder: str = "/tmp", + log_folder: str = f"{WVN_ROOT_DIR}/results", anomaly_detection: bool = False, ): """Returns a confidence value for each number Args: - std_factor (float, optional): _description_. Defaults to 2.0. - device (str, optional): _description_. Defaults to "cpu". + std_factor (float, optional): _description_. Defaults to 0.7. """ super(ConfidenceGenerator, self).__init__() self.std_factor = std_factor @@ -51,13 +51,13 @@ def __init__( running_sum = torch.zeros(1, dtype=torch.float64) running_sum_of_squares = torch.zeros(1, dtype=torch.float64) - # self.running_n = torch.nn.Parameter(running_n, requires_grad=False) - # self.running_sum = torch.nn.Parameter(running_sum, requires_grad=False) - # self.running_sum_of_squares = torch.nn.Parameter(running_sum_of_squares, requires_grad=False) + self.running_n = torch.nn.Parameter(running_n, requires_grad=False) + self.running_sum = torch.nn.Parameter(running_sum, requires_grad=False) + self.running_sum_of_squares = torch.nn.Parameter(running_sum_of_squares, requires_grad=False) - self.running_n = running_n.to("cuda") - self.running_sum = running_sum.to("cuda") - self.running_sum_of_squares = running_sum_of_squares.to("cuda") + # self.running_n = running_n.to("cuda") + # self.running_sum = running_sum.to("cuda") + # self.running_sum_of_squares = running_sum_of_squares.to("cuda") self._update = self.update_running_mean self._reset = self.reset_running_mean @@ -211,12 +211,7 @@ def get_dict(self): if __name__ == "__main__": cg = ConfidenceGenerator() - for i in range(100000): - inp = ( - torch.rand( - 10, - ) - * 10 - ) - res = cg.update(inp, inp) - print("inp ", inp, " res ", res, "std", cg.std) + for i in range(1000): + inp = torch.rand(10) * 10 + res = cg.update(inp, inp, step=i) + # print("inp ", inp, " res ", res, "std", cg.std) diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml index 9b7bb318..66aa0093 100644 --- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml +++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml @@ -21,16 +21,16 @@ image_graph_dist_thr: 0.2 # meters supervision_graph_dist_thr: 0.1 # meters network_input_image_height: 224 # 448 network_input_image_width: 224 # 448 -segmentation_type: "stego" -feature_type: "stego" # TODO verify this here +segmentation_type: "random" # Options: slic, grid, random, stego +feature_type: "dino" # Options: dino, dinov2, stego dino_patch_size: 8 # 8 or 16; 8 is finer +dino_backbone: vit_small slic_num_components: 100 -dino_dim: 384 # 90 or 384; 384 is better confidence_std_factor: 4.0 scale_traversability: False # This parameter needs to be false when using the anomaly detection model scale_traversability_max_fpr: 0.25 min_samples_for_training: 5 -prediction_per_pixel: False +prediction_per_pixel: True traversability_threshold: 0.55 clip_to_binary: False vis_node_index: 10 diff --git a/wild_visual_navigation_ros/scripts/overlay_images.py b/wild_visual_navigation_ros/scripts/overlay_images.py index a6728b12..cffb533a 100644 --- a/wild_visual_navigation_ros/scripts/overlay_images.py +++ b/wild_visual_navigation_ros/scripts/overlay_images.py @@ -1,6 +1,6 @@ import message_filters import rospy -from sensor_msgs.msg import Image, CameraInfo, CompressedImage +from sensor_msgs.msg import Image # , CameraInfo # , CompressedImage import wild_visual_navigation_ros.ros_converter as rc from wild_visual_navigation.visu import LearningVisualizer import sys @@ -12,8 +12,9 @@ def __init__(self): self.value_sub_topic = rospy.get_param("~value_sub_topic") self.image_pub_topic = rospy.get_param("~image_pub_topic") - self.input_pub = rospy.Publisher(f"~{self.image_pub_topic}", Image, queue_size=10) + self._pub = rospy.Publisher(f"~{self.image_pub_topic}", Image, queue_size=1) self._visualizer = LearningVisualizer() + image_sub = message_filters.Subscriber(self.image_sub_topic, Image) trav_sub = message_filters.Subscriber(self.value_sub_topic, Image) sync = message_filters.ApproximateTimeSynchronizer([image_sub, trav_sub], queue_size=2, slop=0.5) @@ -23,7 +24,9 @@ def callback(self, image_msg, trav_msgs): torch_image = rc.ros_image_to_torch(image_msg, device="cpu") torch_trav = rc.ros_image_to_torch(trav_msgs, device="cpu", desired_encoding="passthrough") img_out = self._visualizer.plot_detectron_classification(torch_image, torch_trav.clip(0, 1)) - self.input_pub.publish(rc.numpy_to_ros_image(img_out)) + ros_msg = rc.numpy_to_ros_image(img_out) + ros_msg.header.stamp = image_msg.header.stamp + self._pub.publish(ros_msg) if __name__ == "__main__": @@ -36,7 +39,7 @@ def callback(self, image_msg, trav_msgs): else: nr = "0" # Handle case when no arg is set rospy.init_node(f"wild_visual_navigation_visu_{nr}") - except: + except Exception: rospy.init_node("wild_visual_navigation_visu") wvn = ImageOverlayNode() diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index b3372d45..aa42a263 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -15,14 +15,15 @@ import os import torch import numpy as np +import torch.nn.functional as F +import signal +import sys +import traceback from omegaconf import OmegaConf, read_write from torch_geometric.data import Data -import torch.nn.functional as F from threading import Thread, Event from prettytable import PrettyTable from termcolor import colored -import signal -import sys class WvnFeatureExtractor: @@ -34,6 +35,9 @@ def __init__(self, node_name): self._node_name = node_name self._load_model_counter = 0 + # Timers to control the rate of the subscriber + self.last_image_ts = rospy.get_time() + self.model = get_model(self.params.model).to(self.ros_params.device) self.model.eval() @@ -41,9 +45,10 @@ def __init__(self, node_name): self.ros_params.device, segmentation_type=self.ros_params.segmentation_type, feature_type=self.ros_params.feature_type, + patch_size=self.ros_params.dino_patch_size, + backbone_type=self.ros_params.dino_backbone, input_size=self.ros_params.network_input_image_height, slic_num_components=self.ros_params.slic_num_components, - dino_dim=self.ros_params.dino_dim, ) if not self.anomaly_detection: @@ -54,12 +59,12 @@ def __init__(self, node_name): ) self.ros_params.scale_traversability = True else: - self.traversability_loss = AnomalyLoss( + self.anomaly_loss = AnomalyLoss( **self.params.loss_anomaly, log_enabled=self.params.general.log_confidence, log_folder=self.params.general.model_path, ) - self.traversability_loss.to(self.ros_params.device) + self.anomaly_loss.to(self.ros_params.device) self.ros_params.scale_traversability = False if self.ros_params.verbose: @@ -209,12 +214,12 @@ def setup_ros(self, setup_fully=True): trav_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/traversability", Image, - queue_size=10, + queue_size=1, ) info_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/camera_info", CameraInfo, - queue_size=10, + queue_size=1, ) self.camera_handler[cam]["trav_pub"] = trav_pub self.camera_handler[cam]["info_pub"] = info_pub @@ -226,7 +231,7 @@ def setup_ros(self, setup_fully=True): input_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/image_input", Image, - queue_size=10, + queue_size=1, ) self.camera_handler[cam]["input_pub"] = input_pub @@ -234,7 +239,7 @@ def setup_ros(self, setup_fully=True): conf_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/confidence", Image, - queue_size=10, + queue_size=1, ) self.camera_handler[cam]["conf_pub"] = conf_pub @@ -242,7 +247,7 @@ def setup_ros(self, setup_fully=True): imagefeat_pub = rospy.Publisher( f"/wild_visual_navigation_node/{cam}/feat", ImageFeatures, - queue_size=10, + queue_size=1, ) self.camera_handler[cam]["imagefeat_pub"] = imagefeat_pub @@ -255,121 +260,148 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo info_msg (sensor_msgs/CameraInfo): Camera info message associated to the image cam (str): Camera name """ - if self.ros_params.verbose: - # DEBUG Logging - self.log_data[f"nr_images_{cam}"] += 1 - self.log_data[f"time_last_image_{cam}"] = rospy.get_time() - - # Update model from file if possible - self.load_model() - - # Convert image message to torch image - torch_image = rc.ros_image_to_torch(image_msg, device=self.ros_params.device) - torch_image = self.camera_handler[cam]["image_projector"].resize_image(torch_image) - C, H, W = torch_image.shape - - _, feat, seg, center, dense_feat = self.feature_extractor.extract( - img=torch_image[None], - return_centers=False, - return_dense_features=True, - n_random_pixels=100, - ) - if self.ros_params.prediction_per_pixel: - # Evaluate traversability - data = Data(x=dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])) - else: - # input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1]) - input_feat = feat[seg.reshape(-1)] - data = Data(x=input_feat) + if self.ros_params.verbose: + print(f"[{self._node_name}] Image callback: {cam}... ", end="") + try: + # Run the callback so as to match the desired rate + ts = image_msg.header.stamp.to_sec() + if abs(ts - self.last_image_ts) < 1.0 / self.ros_params.image_callback_rate: + if self.ros_params.verbose: + print(f"skip") + return + else: + if self.ros_params.verbose: + print(f"process") + self.last_image_ts = ts - # Evaluate traversability - prediction = self.model.forward(data) + if self.ros_params.verbose: + # DEBUG Logging + self.log_data[f"nr_images_{cam}"] += 1 + self.log_data[f"time_last_image_{cam}"] = rospy.get_time() + + # Update model from file if possible + self.load_model() + + # Convert image message to torch image + torch_image = rc.ros_image_to_torch(image_msg, device=self.ros_params.device) + torch_image = self.camera_handler[cam]["image_projector"].resize_image(torch_image) + C, H, W = torch_image.shape + + # Extract features + _, feat, seg, center, dense_feat = self.feature_extractor.extract( + img=torch_image[None], + return_centers=False, + return_dense_features=True, + n_random_pixels=100, + ) - if not self.anomaly_detection: - out_trav = prediction.reshape(H, W, -1)[:, :, 0] - - # Publish traversability - if self.ros_params.scale_traversability: - # Apply piecewise linear scaling 0->0; threshold->0.5; 1->1 - traversability = out_trav.clone() - m = traversability < self.ros_params.traversability_threshold - # Scale untraversable - traversability[m] *= 0.5 / self.ros_params.traversability_threshold - # Scale traversable - traversability[~m] -= self.ros_params.traversability_threshold - traversability[~m] *= 0.5 / (1 - self.ros_params.traversability_threshold) - traversability[~m] += 0.5 - traversability = traversability.clip(0, 1) - # TODO Check if this was a bug - out_trav = traversability - else: - loss, loss_aux, trav = self.traversability_loss(None, prediction) + # Forward pass to predict traversability + if self.ros_params.prediction_per_pixel: + # Pixel-wise traversability prediction using the dense features + data = Data(x=dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])) + else: + # input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1]) + # Segment-wise traversability prediction using the average feature per segment + input_feat = feat[seg.reshape(-1)] + data = Data(x=input_feat) + + # Predict traversability per feature + prediction = self.model.forward(data) + + if not self.anomaly_detection: + out_trav = prediction.reshape(H, W, -1)[:, :, 0] + + # Publish traversability + if self.ros_params.scale_traversability: + # Apply piecewise linear scaling 0->0; threshold->0.5; 1->1 + traversability = out_trav.clone() + m = traversability < self.ros_params.traversability_threshold + # Scale untraversable + traversability[m] *= 0.5 / self.ros_params.traversability_threshold + # Scale traversable + traversability[~m] -= self.ros_params.traversability_threshold + traversability[~m] *= 0.5 / (1 - self.ros_params.traversability_threshold) + traversability[~m] += 0.5 + traversability = traversability.clip(0, 1) + # TODO Check if this was a bug + out_trav = traversability + else: + loss, loss_aux, trav = self.anomaly_loss(None, prediction) - out_trav = trav.reshape(H, W, -1)[:, :, 0] + out_trav = trav.reshape(H, W, -1)[:, :, 0] - # Clip to binary output - if self.ros_params.clip_to_binary: - out_trav = torch.where( - out_trav.squeeze() <= self.ros_params.traversability_threshold, - 0.0, - 1.0, - ) + # Clip to binary output + if self.ros_params.clip_to_binary: + out_trav = torch.where( + out_trav.squeeze() <= self.ros_params.traversability_threshold, + 0.0, + 1.0, + ) - msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough") - msg.header = image_msg.header - msg.width = out_trav.shape[0] - msg.height = out_trav.shape[1] - self.camera_handler[cam]["trav_pub"].publish(msg) - - msg = self.camera_handler[cam]["camera_info_msg_out"] - msg.header = image_msg.header - self.camera_handler[cam]["info_pub"].publish(msg) - - # Publish image - if self.ros_params.camera_topics[cam]["publish_input_image"]: - msg = rc.numpy_to_ros_image( - (torch_image.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8), - "rgb8", - ) - msg.header = image_msg.header - msg.width = torch_image.shape[1] - msg.height = torch_image.shape[2] - self.camera_handler[cam]["input_pub"].publish(msg) - - # Publish confidence - if self.ros_params.camera_topics[cam]["publish_confidence"]: - loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1) - confidence = self.confidence_generator.inference_without_update(x=loss_reco) - out_confidence = confidence.reshape(H, W) - msg = rc.numpy_to_ros_image(out_confidence.cpu().numpy(), "passthrough") + msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough") msg.header = image_msg.header - msg.width = out_confidence.shape[0] - msg.height = out_confidence.shape[1] - self.camera_handler[cam]["conf_pub"].publish(msg) + msg.width = out_trav.shape[0] + msg.height = out_trav.shape[1] + self.camera_handler[cam]["trav_pub"].publish(msg) - # Publish features and feature_segments - if self.ros_params.camera_topics[cam]["use_for_training"]: - msg = ImageFeatures() + msg = self.camera_handler[cam]["camera_info_msg_out"] msg.header = image_msg.header - msg.feature_segments = rc.numpy_to_ros_image(seg.cpu().numpy().astype(np.int32), "passthrough") - msg.feature_segments.header = image_msg.header - feat_np = feat.cpu().numpy() + self.camera_handler[cam]["info_pub"].publish(msg) - mad1 = MultiArrayDimension() - mad1.label = "n" - mad1.size = feat_np.shape[0] - mad1.stride = feat_np.shape[0] * feat_np.shape[1] + # Publish image + if self.ros_params.camera_topics[cam]["publish_input_image"]: + msg = rc.numpy_to_ros_image( + (torch_image.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8), + "rgb8", + ) + msg.header = image_msg.header + msg.width = torch_image.shape[1] + msg.height = torch_image.shape[2] + self.camera_handler[cam]["input_pub"].publish(msg) - mad2 = MultiArrayDimension() - mad2.label = "feat" - mad2.size = feat_np.shape[1] - mad2.stride = feat_np.shape[1] + # Publish confidence + if self.ros_params.camera_topics[cam]["publish_confidence"]: + loss_reco = F.mse_loss(prediction[:, 1:], data.x, reduction="none").mean(dim=1) + confidence = self.confidence_generator.inference_without_update(x=loss_reco) + out_confidence = confidence.reshape(H, W) + msg = rc.numpy_to_ros_image(out_confidence.cpu().numpy(), "passthrough") + msg.header = image_msg.header + msg.width = out_confidence.shape[0] + msg.height = out_confidence.shape[1] + self.camera_handler[cam]["conf_pub"].publish(msg) + + # Publish features and feature_segments + if self.ros_params.camera_topics[cam]["use_for_training"]: + msg = ImageFeatures() + msg.header = image_msg.header + msg.feature_segments = rc.numpy_to_ros_image(seg.cpu().numpy().astype(np.int32), "passthrough") + msg.feature_segments.header = image_msg.header + feat_np = feat.cpu().numpy() + + mad1 = MultiArrayDimension() + mad1.label = "n" + mad1.size = feat_np.shape[0] + mad1.stride = feat_np.shape[0] * feat_np.shape[1] + + mad2 = MultiArrayDimension() + mad2.label = "feat" + mad2.size = feat_np.shape[1] + mad2.stride = feat_np.shape[1] + + msg.features.data = feat_np.flatten().tolist() + msg.features.layout.dim.append(mad1) + msg.features.layout.dim.append(mad2) + self.camera_handler[cam]["imagefeat_pub"].publish(msg) - msg.features.data = feat_np.flatten().tolist() - msg.features.layout.dim.append(mad1) - msg.features.layout.dim.append(mad2) - self.camera_handler[cam]["imagefeat_pub"].publish(msg) + except Exception as e: + traceback.print_exc() + rospy.logerr(f"[{self._node_name}] error image callback", e) + self.system_events["image_callback_state"] = { + "time": rospy.get_time(), + "value": f"failed to execute {e}", + } + raise Exception("Error in image callback") def load_model(self): """Method to load the new model weights to perform inference on the incoming images @@ -378,31 +410,31 @@ def load_model(self): None """ try: - self._load_model_counter += 1 - if self._load_model_counter % 10 == 0: - new_model_state_dict = torch.load(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") - k = list(self.model.state_dict().keys())[-1] + # self._load_model_counter += 1 + # if self._load_model_counter % 10 == 0: + new_model_state_dict = torch.load(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + k = list(self.model.state_dict().keys())[-1] - if (self.model.state_dict()[k] != new_model_state_dict[k]).any(): - if self.ros_params.verbose: - self.log_data[f"time_last_model"] = rospy.get_time() - self.log_data[f"nr_model_updates"] += 1 + if (self.model.state_dict()[k] != new_model_state_dict[k]).any(): + if self.ros_params.verbose: + self.log_data[f"time_last_model"] = rospy.get_time() + self.log_data[f"nr_model_updates"] += 1 - self.model.load_state_dict(new_model_state_dict, strict=False) - - try: - if new_model_state_dict["traversability_threshold"] is not None: - # TODO Verify if this works or the writing is need - self.ros_params.traversability_threshold = new_model_state_dict["traversability_threshold"] - if new_model_state_dict["confidence_generator"] is not None: - self.confidence_generator_state = new_model_state_dict["confidence_generator"] + self.model.load_state_dict(new_model_state_dict, strict=False) + try: + if new_model_state_dict["traversability_threshold"] is not None: + # TODO Verify if this works or the writing is need + self.ros_params.traversability_threshold = new_model_state_dict["traversability_threshold"] + if new_model_state_dict["confidence_generator"] is not None: self.confidence_generator_state = new_model_state_dict["confidence_generator"] - self.confidence_generator.var = self.confidence_generator_state["var"] - self.confidence_generator.mean = self.confidence_generator_state["mean"] - self.confidence_generator.std = self.confidence_generator_state["std"] - except Exception: - pass + + self.confidence_generator_state = new_model_state_dict["confidence_generator"] + self.confidence_generator.var = self.confidence_generator_state["var"] + self.confidence_generator.mean = self.confidence_generator_state["mean"] + self.confidence_generator.std = self.confidence_generator_state["std"] + except Exception: + pass except Exception as e: if self.ros_params.verbose: diff --git a/wild_visual_navigation_ros/scripts/wvn_learning_node.py b/wild_visual_navigation_ros/scripts/wvn_learning_node.py index 6cbbee14..3fbae134 100644 --- a/wild_visual_navigation_ros/scripts/wvn_learning_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_learning_node.py @@ -77,7 +77,6 @@ def __init__(self, node_name): vis_node_index=self.ros_params.vis_node_index, mode=self.ros_params.mode, extraction_store_folder=self.ros_params.extraction_store_folder, - patch_size=self.ros_params.dino_patch_size, scale_traversability=self.ros_params.scale_traversability, anomaly_detection=self.anomaly_detection, ) @@ -173,7 +172,6 @@ def learning_thread_loop(self): """ # Set rate rate = rospy.Rate(self.ros_params.learning_thread_rate) - i = 0 # Learning loop while True: self.system_events["learning_thread_loop"] = { @@ -204,35 +202,33 @@ def learning_thread_loop(self): system_state.step = self.step self.pub_system_state.publish(system_state) - rate.sleep() - if i % 10 == 0: - res = self.traversability_estimator._model.state_dict() - - # Compute ROC Threshold - if self.ros_params.scale_traversability: - if self.traversability_estimator._auxiliary_training_roc._update_count != 0: - try: - ( - fpr, - tpr, - thresholds, - ) = self.traversability_estimator._auxiliary_training_roc.compute() - index = torch.where(fpr > self.ros_params.scale_traversability_max_fpr)[0][0] - traversability_threshold = thresholds[index] - except Exception: - traversability_threshold = 0.5 - else: + # Get current weights + new_model_state_dict = self.traversability_estimator._model.state_dict() + + # Compute ROC Threshold + if self.ros_params.scale_traversability: + if self.traversability_estimator._auxiliary_training_roc._update_count != 0: + try: + ( + fpr, + tpr, + thresholds, + ) = self.traversability_estimator._auxiliary_training_roc.compute() + index = torch.where(fpr > self.ros_params.scale_traversability_max_fpr)[0][0] + traversability_threshold = thresholds[index] + except Exception: traversability_threshold = 0.5 + else: + traversability_threshold = 0.5 - res["traversability_threshold"] = traversability_threshold - cg = self.traversability_estimator._traversability_loss._confidence_generator - res["confidence_generator"] = cg.get_dict() + new_model_state_dict["traversability_threshold"] = traversability_threshold + cg = self.traversability_estimator._traversability_loss._confidence_generator + new_model_state_dict["confidence_generator"] = cg.get_dict() - os.remove( - f"{WVN_ROOT_DIR}/.tmp_state_dict.pt", - ) - torch.save(res, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") - i += 1 + os.remove(f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + torch.save(new_model_state_dict, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") + + rate.sleep() self.system_events["learning_thread_loop"] = { "time": rospy.get_time(), @@ -528,9 +524,9 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): try: ts = state_msg.header.stamp.to_sec() if abs(ts - self.last_supervision_ts) < 1.0 / self.ros_params.supervision_callback_rate: - self.system_events["robot_state_callback_cancled"] = { + self.system_events["robot_state_callback_canceled"] = { "time": rospy.get_time(), - "value": "cancled due to rate", + "value": "canceled due to rate", } return self.last_propio_ts = ts @@ -545,9 +541,9 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): device=self.ros_params.device, ) if not success: - self.system_events["robot_state_callback_cancled"] = { + self.system_events["robot_state_callback_canceled"] = { "time": rospy.get_time(), - "value": "cancled due to pose_base_in_world", + "value": "canceled due to pose_base_in_world", } return @@ -644,8 +640,10 @@ def imagefeat_callback(self, *args): "time": rospy.get_time(), "value": "message received", } + if self.ros_params.verbose: print(f"[{self._node_name}] Image callback: {camera_options['name']}... ", end="") + try: # Run the callback so as to match the desired rate ts = imagefeat_msg.header.stamp.to_sec() @@ -668,11 +666,12 @@ def imagefeat_callback(self, *args): device=self.ros_params.device, ) if not success: - self.system_events["image_callback_cancled"] = { + self.system_events["image_callback_canceled"] = { "time": rospy.get_time(), - "value": "cancled due to pose_base_in_world", + "value": "canceled due to pose_base_in_world", } return + success, pose_cam_in_base = rc.ros_tf_to_torch( self.query_tf( self.ros_params.base_frame, @@ -681,13 +680,13 @@ def imagefeat_callback(self, *args): ), device=self.ros_params.device, ) - if not success: - self.system_events["image_callback_cancled"] = { + self.system_events["image_callback_canceled"] = { "time": rospy.get_time(), "value": "canceled due to pose_cam_in_base", } return + # Prepare image projector K, H, W = rc.ros_cam_info_to_tensors(info_msg, device=self.ros_params.device) image_projector = ImageProjector(