diff --git a/data/SimpleTestData/0-10.npz b/data/model_eval/SimpleTestData/0-10.npz similarity index 100% rename from data/SimpleTestData/0-10.npz rename to data/model_eval/SimpleTestData/0-10.npz diff --git a/data/SimpleTestData/1-10.npz b/data/model_eval/SimpleTestData/1-10.npz similarity index 100% rename from data/SimpleTestData/1-10.npz rename to data/model_eval/SimpleTestData/1-10.npz diff --git a/data/SimpleTestData/10-100.npz b/data/model_eval/SimpleTestData/10-100.npz similarity index 100% rename from data/SimpleTestData/10-100.npz rename to data/model_eval/SimpleTestData/10-100.npz diff --git a/data/SimpleTestData/11-100.npz b/data/model_eval/SimpleTestData/11-100.npz similarity index 100% rename from data/SimpleTestData/11-100.npz rename to data/model_eval/SimpleTestData/11-100.npz diff --git a/data/SimpleTestData/12-100.npz b/data/model_eval/SimpleTestData/12-100.npz similarity index 100% rename from data/SimpleTestData/12-100.npz rename to data/model_eval/SimpleTestData/12-100.npz diff --git a/data/SimpleTestData/13-100.npz b/data/model_eval/SimpleTestData/13-100.npz similarity index 100% rename from data/SimpleTestData/13-100.npz rename to data/model_eval/SimpleTestData/13-100.npz diff --git a/data/SimpleTestData/14-100.npz b/data/model_eval/SimpleTestData/14-100.npz similarity index 100% rename from data/SimpleTestData/14-100.npz rename to data/model_eval/SimpleTestData/14-100.npz diff --git a/data/SimpleTestData/15-100.npz b/data/model_eval/SimpleTestData/15-100.npz similarity index 100% rename from data/SimpleTestData/15-100.npz rename to data/model_eval/SimpleTestData/15-100.npz diff --git a/data/SimpleTestData/16-100.npz b/data/model_eval/SimpleTestData/16-100.npz similarity index 100% rename from data/SimpleTestData/16-100.npz rename to data/model_eval/SimpleTestData/16-100.npz diff --git a/data/SimpleTestData/17-100.npz b/data/model_eval/SimpleTestData/17-100.npz similarity index 100% rename from data/SimpleTestData/17-100.npz rename to data/model_eval/SimpleTestData/17-100.npz diff --git a/data/SimpleTestData/18-100.npz b/data/model_eval/SimpleTestData/18-100.npz similarity index 100% rename from data/SimpleTestData/18-100.npz rename to data/model_eval/SimpleTestData/18-100.npz diff --git a/data/SimpleTestData/19-100.npz b/data/model_eval/SimpleTestData/19-100.npz similarity index 100% rename from data/SimpleTestData/19-100.npz rename to data/model_eval/SimpleTestData/19-100.npz diff --git a/data/SimpleTestData/2-100.npz b/data/model_eval/SimpleTestData/2-100.npz similarity index 100% rename from data/SimpleTestData/2-100.npz rename to data/model_eval/SimpleTestData/2-100.npz diff --git a/data/SimpleTestData/20-100.npz b/data/model_eval/SimpleTestData/20-100.npz similarity index 100% rename from data/SimpleTestData/20-100.npz rename to data/model_eval/SimpleTestData/20-100.npz diff --git a/data/SimpleTestData/21-100.npz b/data/model_eval/SimpleTestData/21-100.npz similarity index 100% rename from data/SimpleTestData/21-100.npz rename to data/model_eval/SimpleTestData/21-100.npz diff --git a/data/SimpleTestData/22-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/22-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/22-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/22-10-FastSAM-x.npz diff --git a/data/SimpleTestData/23-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/23-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/23-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/23-10-FastSAM-x.npz diff --git a/data/SimpleTestData/24-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/24-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/24-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/24-10-FastSAM-x.npz diff --git a/data/SimpleTestData/25-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/25-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/25-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/25-10-FastSAM-x.npz diff --git a/data/SimpleTestData/26-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/26-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/26-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/26-10-FastSAM-x.npz diff --git a/data/SimpleTestData/27-10-FastSAM-x.npz b/data/model_eval/SimpleTestData/27-10-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/27-10-FastSAM-x.npz rename to data/model_eval/SimpleTestData/27-10-FastSAM-x.npz diff --git a/data/SimpleTestData/28-100-FastSAM-x.npz b/data/model_eval/SimpleTestData/28-100-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/28-100-FastSAM-x.npz rename to data/model_eval/SimpleTestData/28-100-FastSAM-x.npz diff --git a/data/SimpleTestData/29-100-FastSAM-x.npz b/data/model_eval/SimpleTestData/29-100-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/29-100-FastSAM-x.npz rename to data/model_eval/SimpleTestData/29-100-FastSAM-x.npz diff --git a/data/SimpleTestData/3-100.npz b/data/model_eval/SimpleTestData/3-100.npz similarity index 100% rename from data/SimpleTestData/3-100.npz rename to data/model_eval/SimpleTestData/3-100.npz diff --git a/data/SimpleTestData/30-100-FastSAM-x.npz b/data/model_eval/SimpleTestData/30-100-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/30-100-FastSAM-x.npz rename to data/model_eval/SimpleTestData/30-100-FastSAM-x.npz diff --git a/data/SimpleTestData/31-100-FastSAM-x.npz b/data/model_eval/SimpleTestData/31-100-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/31-100-FastSAM-x.npz rename to data/model_eval/SimpleTestData/31-100-FastSAM-x.npz diff --git a/data/SimpleTestData/32-100-FastSAM-x.npz b/data/model_eval/SimpleTestData/32-100-FastSAM-x.npz similarity index 100% rename from data/SimpleTestData/32-100-FastSAM-x.npz rename to data/model_eval/SimpleTestData/32-100-FastSAM-x.npz diff --git a/data/SimpleTestData/4-100.npz b/data/model_eval/SimpleTestData/4-100.npz similarity index 100% rename from data/SimpleTestData/4-100.npz rename to data/model_eval/SimpleTestData/4-100.npz diff --git a/data/SimpleTestData/5-100.npz b/data/model_eval/SimpleTestData/5-100.npz similarity index 100% rename from data/SimpleTestData/5-100.npz rename to data/model_eval/SimpleTestData/5-100.npz diff --git a/data/SimpleTestData/6-100.npz b/data/model_eval/SimpleTestData/6-100.npz similarity index 100% rename from data/SimpleTestData/6-100.npz rename to data/model_eval/SimpleTestData/6-100.npz diff --git a/data/SimpleTestData/7-100.npz b/data/model_eval/SimpleTestData/7-100.npz similarity index 100% rename from data/SimpleTestData/7-100.npz rename to data/model_eval/SimpleTestData/7-100.npz diff --git a/data/SimpleTestData/8-100.npz b/data/model_eval/SimpleTestData/8-100.npz similarity index 100% rename from data/SimpleTestData/8-100.npz rename to data/model_eval/SimpleTestData/8-100.npz diff --git a/data/SimpleTestData/9-100.npz b/data/model_eval/SimpleTestData/9-100.npz similarity index 100% rename from data/SimpleTestData/9-100.npz rename to data/model_eval/SimpleTestData/9-100.npz diff --git a/src/data_collection/data_loader.py b/src/data_collection/data_loader.py index c7342dc..e9ac2a0 100644 --- a/src/data_collection/data_loader.py +++ b/src/data_collection/data_loader.py @@ -75,6 +75,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch. states_tensor = states_tensor .reshape((-1, 12, 128, 128)) masks_tensor = torch.from_numpy(np.array(masks))[:, :self.num_obj] - masks_tensor = F.one_hot(masks_tensor, num_classes=self.num_obj).float()[1:] # get rid of background + masks_tensor = F.one_hot(masks_tensor.long(), num_classes=self.num_obj + 1).float()[1:] # get rid of background [B, H, W, O] + masks_tensor = masks_tensor.permute(0, 1, 4, 2, 3) # [B, O, H, W] return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions)) diff --git a/src/scripts/train_prediction.py b/src/scripts/train_prediction.py index 3a5b37e..b6b2301 100644 --- a/src/scripts/train_prediction.py +++ b/src/scripts/train_prediction.py @@ -17,7 +17,7 @@ def train(config: DictConfig, batch_size: int = 4, t_steps: int = 1, num_obj: in device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") - data_loader = DataLoader("SimpleTestDataSmall", num_obj) + data_loader = DataLoader("SimpleTestData", num_obj) feature_extract = FeatureExtractor(num_objects=num_obj).to(device) predictor = Predictor(num_layers=1, time_steps=t_steps).to(device)