Commit
add local download command to datasets, add command for downloading datasets by name to setup.cfg, update repo name for import in test_pipeline, add cache files to gitignore, modified dataset loading in train.py to use lowercase names for easier / more uniform writing.
bgenchel committed Sep 22, 2023
1 parent 8bc8a26 commit 37db981
Showing 12 changed files with 106 additions and 79 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -148,3 +148,7 @@ junit.xml

# intellij
.idea/

# cache
*_set_cache_*.index
*_set_cache_*.data*
35 changes: 14 additions & 21 deletions basic_pitch/dataset/commandline.py
@@ -16,35 +16,28 @@
# limitations under the License.

import argparse
import inspect
import os
import os.path as op
import pdb


def add_default(parser: argparse.ArgumentParser):
def add_default(parser: argparse.ArgumentParser, dataset_name: str):
parser.add_argument("source", nargs='?', default=op.join(op.expanduser('~'), 'mir_datasets'),
help="Source directory for mir data. Defaults to local mir_datasets folder.")
parser.add_argument("destination", nargs='?',
default=op.join(op.expanduser('~'), 'data', 'basic_pitch', dataset_name),
help="Output directory to write results to. Defaults to local ~/data/basic_pitch/{dataset}/")
parser.add_argument("--runner", choices=["DataflowRunner", "DirectRunner"], default="DirectRunner")
parser.add_argument(
"source",
help="Source directory for mir data.",
)
parser.add_argument(
"destination",
help="Output directory to write results to.",
)
parser.add_argument(
"--timestamped",
default=False,
action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'",
)
parser.add_argument("--timestamped", default=False, action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'")
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument(
"--worker-harness-container-image",
default="",
help="Container image to run dataset generation job with. Required due to non-python dependencies",
)
parser.add_argument("--worker-harness-container-image", default="",
help="Container image to run dataset generation job with. Required due to non-python dependencies")


def resolve_destination(namespace: argparse.Namespace, dataset: str, time_created: int) -> str:
return os.path.join(namespace.destination, dataset, str(time_created) if namespace.timestamped else "splits")
return os.path.join(namespace.destination, str(time_created) if namespace.timestamped else "splits")


def add_split(
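For reference, a minimal sketch of how the reworked arguments compose, assuming the module layout in this commit (the asserted paths are simply the defaults added above, not output from a real run):

```python
# Illustration only: exercises the new add_default / resolve_destination behavior.
import argparse
import os.path as op

from basic_pitch.dataset import commandline

parser = argparse.ArgumentParser()
commandline.add_default(parser, "ikala")  # the dataset name now seeds the default destination
args = parser.parse_args([])              # both positionals are optional (nargs='?')

assert args.source == op.join(op.expanduser("~"), "mir_datasets")
assert args.destination == op.join(op.expanduser("~"), "data", "basic_pitch", "ikala")

# Because the dataset directory is already baked into args.destination,
# resolve_destination now appends only "splits" (or a timestamp when --timestamped is passed).
print(commandline.resolve_destination(args, "ikala", 1695340800))
# -> <home>/data/basic_pitch/ikala/splits
```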
33 changes: 33 additions & 0 deletions basic_pitch/dataset/download.py
@@ -0,0 +1,33 @@
import argparse

from basic_pitch.dataset import commandline
from basic_pitch.dataset.guitarset import main as guitarset_main
from basic_pitch.dataset.ikala import main as ikala_main
from basic_pitch.dataset.maestro import main as maestro_main
from basic_pitch.dataset.medleydb_pitch import main as medleydb_pitch_main
from basic_pitch.dataset.slakh import main as slakh_main

dataset_dict = {
'guitarset': guitarset_main,
'ikala': ikala_main,
'maestro': maestro_main,
'medleydb_pitch': medleydb_pitch_main,
'slakh': slakh_main
}

def main():
dataset_parser = argparse.ArgumentParser()
dataset_parser.add_argument("dataset", choices=list(dataset_dict.keys()), help="The dataset to download / process.")
dataset = dataset_parser.parse_args().dataset

print(f'got the arg: {dataset}')
cl_parser = argparse.ArgumentParser()
commandline.add_default(cl_parser, dataset)
commandline.add_split(cl_parser)
known_args, pipeline_args = cl_parser.parse_known_args() # sys.argv)

dataset_dict[dataset](known_args, pipeline_args)


if __name__ == '__main__':
main()
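A short sketch of the core dispatch this module performs, with the dataset fixed to one value. The console-script name registered in setup.cfg is not visible in this diff, so nothing is claimed about it here; the commented-out call is left out because it would launch the Beam job.

```python
# The core of download.main(): pick the per-dataset main() out of dataset_dict
# and hand it the argparse namespaces built by commandline.add_default / add_split.
import argparse

from basic_pitch.dataset import commandline
from basic_pitch.dataset.download import dataset_dict

dataset = "guitarset"  # any key: guitarset, ikala, maestro, medleydb_pitch, slakh

cl_parser = argparse.ArgumentParser()
commandline.add_default(cl_parser, dataset)
commandline.add_split(cl_parser)
known_args, pipeline_args = cl_parser.parse_known_args([])  # rely on the new local defaults

# dataset_dict[dataset](known_args, pipeline_args)  # would run the DirectRunner job
print(dataset_dict[dataset].__module__)  # basic_pitch.dataset.guitarset
```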
24 changes: 12 additions & 12 deletions basic_pitch/dataset/guitarset.py
@@ -18,6 +18,7 @@
import argparse
import logging
import os
import os.path as op
import random
import sys
import time
@@ -28,7 +29,7 @@

from basic_pitch.dataset import commandline, pipeline

GUITARSET_DIR = "GuitarSet"
GUITARSET_DIR = "guitarset" # "GuitarSet"


class GuitarSetInvalidTracks(beam.DoFn):
@@ -47,7 +48,7 @@ def setup(self):
import apache_beam as beam
import mirdata

self.guitarset_remote = mirdata.initialize("guitarset", data_home=os.path.join(self.source, "GuitarSet"))
self.guitarset_remote = mirdata.initialize("guitarset", data_home=os.path.join(self.source, GUITARSET_DIR))
self.filesystem = beam.io.filesystems.FileSystems()

def process(self, element: List[str]):
@@ -144,32 +145,26 @@ def determine_split() -> str:
return "test"

guitarset = mirdata.initialize("guitarset")
guitarset.download()

return [(track_id, determine_split()) for track_id in guitarset.track_ids]


def main():
parser = argparse.ArgumentParser()
commandline.add_default(parser)
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, GUITARSET_DIR, time_created)
input_data = create_input_data(known_args.train_percent, known_args.validation_percent, known_args.split_seed)

pipeline_options = {
"runner": known_args.runner,
"project": "audio-understanding",
"job_name": f"guitarset-tfrecords-{time_created}",
"region": "europe-west1",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
"experiments": ["use_runner_v2"],
"save_main_session": True,
"worker_harness_container_image": known_args.worker_harness_container_image,
}
input_data = create_input_data(known_args.train_percent, known_args.validation_percent, known_args.split_seed)
pipeline.run(
pipeline_options,
input_data,
@@ -181,4 +176,9 @@ def main():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args() # parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
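The same `__main__` pattern repeats in the dataset modules below; the name passed to `add_default` is derived from the module filename. A small worked example of that derivation (the path is a stand-in for `__file__`):

```python
import os.path as op

module_path = "/repo/basic_pitch/dataset/guitarset.py"  # stand-in for __file__
dataset_name = op.basename(op.splitext(module_path)[0])
print(dataset_name)  # "guitarset" -> default destination becomes ~/data/basic_pitch/guitarset
```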
18 changes: 9 additions & 9 deletions basic_pitch/dataset/ikala.py
@@ -18,6 +18,7 @@
import argparse
import logging
import os
import os.path as op
import random
import sys
import time
@@ -147,24 +148,18 @@ def determine_split() -> str:
return "validation"

ikala = mirdata.initialize("ikala")
ikala.download()

return [(track_id, determine_split()) for track_id in ikala.track_ids]


def main():
parser = argparse.ArgumentParser()
commandline.add_default(parser)
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, "iKala", time_created)

pipeline_options = {
"runner": known_args.runner,
"project": "audio-understanding",
"job_name": f"ikala-tfrecords-{time_created}",
"region": "europe-west1",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
@@ -184,4 +179,9 @@ def main():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
17 changes: 8 additions & 9 deletions basic_pitch/dataset/maestro.py
@@ -18,6 +18,7 @@
import argparse
import logging
import os
import os.path as op
import sys
import tempfile
import time
@@ -193,21 +194,14 @@ def create_input_data(source: str) -> List[Tuple[str, str]]:
return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]


def main():
parser = argparse.ArgumentParser()
commandline.add_default(parser)

known_args, pipeline_args = parser.parse_known_args(sys.argv)

def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, MAESTRO_DIR, time_created)

# TODO: Remove or abstract for foss
pipeline_options = {
"runner": known_args.runner,
"project": "audio-understanding",
"job_name": f"maestro-tfrecords-{time_created}",
"region": "europe-west1",
"machine_type": "e2-highmem-4",
"num_workers": 25,
"disk_size_gb": 128,
@@ -227,4 +221,9 @@ def main():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splitext(__file__)[0]))

known_args, pipeline_args = parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
18 changes: 9 additions & 9 deletions basic_pitch/dataset/medleydb_pitch.py
@@ -18,6 +18,7 @@
import argparse
import logging
import os
import os.path as op
import random
import sys
import time
@@ -150,24 +151,18 @@ def determine_split() -> str:
return "validation"

medleydb_pitch = mirdata.initialize("medleydb_pitch")
medleydb_pitch.download()

return [(track_id, determine_split()) for track_id in medleydb_pitch.track_ids]


def main():
parser = argparse.ArgumentParser()
commandline.add_default(parser)
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, "MedleyDB-Pitch", time_created)

pipeline_options = {
"runner": known_args.runner,
"project": "audio-understanding",
"job_name": f"medleydb-pitch-tfrecords-{time_created}",
"region": "europe-west1",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
@@ -187,4 +182,9 @@ def main():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
18 changes: 9 additions & 9 deletions basic_pitch/dataset/slakh.py
@@ -18,6 +18,7 @@
import argparse
import logging
import os
import os.path as op
import sys
import time
from typing import List, Tuple
@@ -184,24 +185,18 @@ def process(self, element: List[str]):

def create_input_data() -> List[Tuple[str, str]]:
slakh = mirdata.initialize("slakh")
slakh.download()

return [(track_id, track.data_split) for track_id, track in slakh.load_tracks().items()]


def main():
parser = argparse.ArgumentParser()
commandline.add_default(parser)
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

def main(known_args, pipeline_args):
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, "slakh", time_created)

pipeline_options = {
"runner": known_args.runner,
"project": "audio-understanding",
"job_name": f"slakh-tfrecords-{time_created}",
"region": "europe-west1",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
@@ -221,4 +216,9 @@ def main():


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
commandline.add_default(parser, op.basename(op.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args() # sys.argv)

main(known_args, pipeline_args)
5 changes: 1 addition & 4 deletions basic_pitch/dataset/tf_example_deserialization.py
@@ -24,7 +24,6 @@

# import tensorflow_addons as tfa


from basic_pitch.constants import (
ANNOTATIONS_FPS,
ANNOT_N_FRAMES,
@@ -54,7 +53,6 @@ def prepare_datasets(
training_shuffle_buffer_size : size of shuffle buffer (only for training set)
batch_size : ..
"""

assert batch_size > 0
assert validation_steps is not None and validation_steps > 0
assert training_shuffle_buffer_size is not None
@@ -170,7 +168,6 @@ def sample_datasets(
pairs=False,
num_parallel_calls=6,
):

assert split in ["train", "validation"]
if split == "validation":
n_shuffle = 0
Expand All @@ -180,6 +177,7 @@ def sample_datasets(

ds_list = []


file_generator, random_seed = transcription_file_generator(
split,
datasets,
@@ -227,7 +225,6 @@ def transcription_file_generator(
"""
dataset_names: list of dataset dataset_names
"""

file_dict = {
dataset_name: tf.data.Dataset.list_files(
os.path.join(datasets_base_path, dataset_name, "splits", split, "*tfrecord")
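For context on how the write and read sides meet: `transcription_file_generator` globs tfrecords per dataset and split, and the layout it expects lines up with the destination default introduced in `commandline.add_default`. A sketch with illustrative dataset and split values:

```python
import os
import os.path as op

# Write side: resolve_destination(...) yields ~/data/basic_pitch/<dataset>/splits,
# with tfrecords landing under per-split subdirectories.
# Read side: the generator joins the same pieces back together for its glob.
datasets_base_path = op.join(op.expanduser("~"), "data", "basic_pitch")
pattern = os.path.join(datasets_base_path, "guitarset", "splits", "train", "*tfrecord")
print(pattern)  # the glob handed to tf.data.Dataset.list_files for the train split
```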