Skip to content

Commit

Permalink
use multiprocessing manager list
Browse files Browse the repository at this point in the history
  • Loading branch information
thayeral committed Nov 12, 2024
1 parent dba2124 commit 57956fd
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
from functools import partial
from multiprocessing import Manager
from pathlib import Path

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -530,19 +531,23 @@ def __init__(
self.filename_pattern = filename_pattern
self.cpu_workers = cpu_workers
self.model_input_shape = model_input_shape

self.files = collect_files(
datadir,
modes=self.modes,
samplelimit=self.samplelimit,
embedding=self.embedding,
distribution=self.distribution,
max_amplitude=self.max_amplitude,
photons_range=self.photons_range,
npoints_range=self.npoints_range,
filename_pattern=self.filename_pattern,
cpu_workers=self.cpu_workers
)

manager = Manager()

self.files = manager.list(
collect_files(
datadir,
modes=self.modes,
samplelimit=self.samplelimit,
embedding=self.embedding,
distribution=self.distribution,
max_amplitude=self.max_amplitude,
photons_range=self.photons_range,
npoints_range=self.npoints_range,
filename_pattern=self.filename_pattern,
cpu_workers=self.cpu_workers
))


def __len__(self):
return len(self.files)
Expand All @@ -562,14 +567,15 @@ def __getitem__(self, idx):
defocus_only=self.defocus_only
)
else:
return get_sample(
x, y = get_sample(
path=path,
iotf=self.iotf,
input_coverage=self.input_coverage,
embedding_option=self.embedding_option,
lls_defocus=self.lls_defocus,
defocus_only=self.defocus_only
)
return torch.tensor(x), torch.tensor(y)


class RayDataset:
Expand Down

0 comments on commit 57956fd

Please sign in to comment.