diff --git a/basic_pitch/data/datasets/guitarset.py b/basic_pitch/data/datasets/guitarset.py index cbe89bf..978b275 100644 --- a/basic_pitch/data/datasets/guitarset.py +++ b/basic_pitch/data/datasets/guitarset.py @@ -136,17 +136,19 @@ def create_input_data( if seed: random.seed(seed) - def determine_split() -> str: - partition = random.uniform(0, 1) - if partition < validation_bound: + def determine_split(index: int) -> str: + if index < len(track_ids) * validation_bound: return "train" - if partition < test_bound: + elif index < len(track_ids) * test_bound: return "validation" - return "test" + else: + return "test" guitarset = mirdata.initialize("guitarset") + track_ids = guitarset.track_ids + random.shuffle(track_ids) - return [(track_id, determine_split()) for track_id in guitarset.track_ids] + return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)] def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None: diff --git a/basic_pitch/data/datasets/ikala.py b/basic_pitch/data/datasets/ikala.py index b56954d..6ed23d4 100644 --- a/basic_pitch/data/datasets/ikala.py +++ b/basic_pitch/data/datasets/ikala.py @@ -138,21 +138,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]: assert train_percent < 1.0, "Don't over allocate the data!" - # Test percent is 1 - train - validation - validation_bound = train_percent - if seed: random.seed(seed) - def determine_split() -> str: - partition = random.uniform(0, 1) - if partition < validation_bound: - return "train" - return "validation" - ikala = mirdata.initialize("ikala") + track_ids = ikala.track_ids + random.shuffle(track_ids) + + def determine_split(index: int) -> str: + return "train" if index < len(track_ids) * train_percent else "validation" - return [(track_id, determine_split()) for track_id in ikala.track_ids] + return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)] def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None: diff --git a/basic_pitch/data/datasets/medleydb_pitch.py b/basic_pitch/data/datasets/medleydb_pitch.py index 891afa4..c7083ce 100644 --- a/basic_pitch/data/datasets/medleydb_pitch.py +++ b/basic_pitch/data/datasets/medleydb_pitch.py @@ -136,22 +136,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]: assert train_percent < 1.0, "Don't over allocate the data!" - # Test percent is 1 - train - validation - validation_bound = train_percent - if seed: random.seed(seed) - def determine_split() -> str: - partition = random.uniform(0, 1) - if partition < validation_bound: - return "train" - return "validation" - medleydb_pitch = mirdata.initialize("medleydb_pitch") - medleydb_pitch.download() + track_ids = medleydb_pitch.track_ids + random.shuffle(track_ids) + + def determine_split(index: int) -> str: + return "train" if index < len(track_ids) * train_percent else "validation" - return [(track_id, determine_split()) for track_id in medleydb_pitch.track_ids] + return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)] def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None: diff --git a/tests/data/test_medleydb_pitch.py b/tests/data/test_medleydb_pitch.py index 5c056f2..f2c7b23 100644 --- a/tests/data/test_medleydb_pitch.py +++ b/tests/data/test_medleydb_pitch.py @@ -54,7 +54,7 @@ def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None: def test_medleydb_create_input_data() -> None: data = create_input_data(train_percent=0.5) data.sort(key=lambda el: el[1]) # sort by split - tolerance = 0.05 + tolerance = 0.01 for _, group in itertools.groupby(data, lambda el: el[1]): assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)