Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
Former-commit-id: a2dbc9a7c48578865b5c19db6d6a6bea2de73346
  • Loading branch information
Javi Ribera committed Feb 4, 2018
1 parent c66cb5d commit a544881
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions deliverable/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,43 +58,52 @@
if args.cuda:
torch.cuda.manual_seed_all(args.seed)

# Create output directories
os.makedirs(os.path.join(args.out_dir, 'painted'), exist_ok=True)
os.makedirs(os.path.join(args.out_dir, 'est_map'), exist_ok=True)
os.makedirs(os.path.join(args.out_dir, 'est_map_thresholded'), exist_ok=True)


class PlantDataset(data.Dataset):
def __init__(self, root_dir, transform=None, max_dataset_size=np.inf):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
max_dataset_size: If the dataset is bigger than this integer,
ignore additional samples.
class CSVDataset(data.Dataset):
def __init__(self, directory, transform=None, max_dataset_size=np.inf):
"""CSVDataset.
The sample images of this dataset must be all inside one directory.
Inside the same directory, there must be one CSV file.
This file must contain one row per image.
It can containas many columns as wanted, i.e, filename, count...
:param directory: Directory with all the images and the CSV file.
:param transform: Transform to be applied to each image.
:param max_dataset_size: Only use the first N images in the directory.
"""

# Get groundtruth from CSV file
csv_filename = None
for filename in os.listdir(root_dir):
for filename in os.listdir(directory):
if filename.endswith('.csv'):
csv_filename = filename
break
if csv_filename is None:
raise ValueError(
'The root directory %s does not have a CSV file with groundtruth' % root_dir)
self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))
'The root directory %s does not have a CSV file with groundtruth' % directory)
self.csv_df = pd.read_csv(os.path.join(directory, csv_filename))

# Make the dataset smaller
self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]

self.root_dir = root_dir
self.root_dir = directory
self.transform = transform

def __len__(self):
return len(self.csv_df)

def __getitem__(self, idx):
"""Get one element of the dataset.
Returns a tuple. The first element is the image.
The second element is a dictionary where the keys are the columns of the CSV.
:param idx: Index of the image in the dataset to get.
"""
img_path = os.path.join(self.root_dir, self.csv_df.ix[idx, 0])
img = skimage.io.imread(img_path)
dictionary = dict(self.csv_df.ix[idx])
Expand All @@ -114,13 +123,13 @@ def __getitem__(self, idx):
% args.eval_batch_size)

# Data loading code
testset = PlantDataset(args.test_dir,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_testset_size)
testset = CSVDataset(args.test_dir,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_testset_size)
testset_loader = data.DataLoader(testset,
batch_size=args.eval_batch_size,
num_workers=args.nThreads)
Expand Down

0 comments on commit a544881

Please sign in to comment.