diff --git a/object-locator/data.py b/object-locator/data.py index 4110d8a..32afc29 100644 --- a/object-locator/data.py +++ b/object-locator/data.py @@ -60,6 +60,9 @@ def __init__(self, listfiles = [f for f in listfiles if any(f.lower().endswith(ext) for ext in IMG_EXTENSIONS)] + # Shuffle list of files + random.shuffle(listfiles) + if len(listfiles) == 0: raise ValueError(f"There are no images in '{directory}'") @@ -73,13 +76,16 @@ def __init__(self, self.listfiles = listfiles # Make dataset smaller - self.listfiles = self.listfiles[0:min( - len(self.listfiles), max_dataset_size)] + self.listfiles = self.listfiles[0:min(len(self.listfiles), + max_dataset_size)] # CSV does exist (GT is available) else: self.csv_df = pd.read_csv(os.path.join(directory, csv_filename)) + # Shuffle CSV dataframe + self.csv_df = self.csv_df.sample(frac=1).reset_index(drop=True) + # Make dataset smaller self.csv_df = self.csv_df[0:min( len(self.csv_df), max_dataset_size)] @@ -96,7 +102,7 @@ def __getitem__(self, idx): The second element is a dictionary where the keys are the columns of the CSV. If the CSV did not exist in the dataset directory, the dictionary will only contain the filename of the image. - + :param idx: Index of the image in the dataset to get. """ @@ -333,6 +339,9 @@ def __init__(self, listfiles = [f for f in listfiles if any(f.lower().endswith(ext) for ext in IMG_EXTENSIONS)] + # Shuffle list of files + random.shuffle(listfiles) + if len(listfiles) == 0: raise ValueError(f"There are no images in '{directory}'")