Skip to content

Commit

Permalink
added two tests for example_deserialization, and made some correspond…
Browse files Browse the repository at this point in the history
…ing changes in the original file.
  • Loading branch information
bgenchel committed Aug 9, 2024
1 parent ffb2db4 commit 6e3710d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
6 changes: 3 additions & 3 deletions basic_pitch/data/tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def transcription_file_generator(
dataset_names: List[str],
datasets_base_path: str,
sample_weights: np.ndarray,
) -> Tuple[Callable[[], Iterator[str]], bool]:
) -> Tuple[Callable[[], Iterator[tf.Tensor]], bool]:
"""
dataset_names: list of dataset dataset_names
"""
Expand All @@ -235,7 +235,7 @@ def transcription_file_generator(
return lambda: _validation_file_generator(file_dict), True


def _train_file_generator(x: Dict[str, List[str]], weights: np.ndarray) -> Iterator[str]:
def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]:
x = {k: list(v) for (k, v) in x.items()}
keys = list(x.keys())
# shuffle each list
Expand All @@ -248,7 +248,7 @@ def _train_file_generator(x: Dict[str, List[str]], weights: np.ndarray) -> Itera
yield fpath


def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[str]:
def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Tensor]:
x = {k: list(v) for (k, v) in x.items()}
# loop until there are no more test files
while any(x.values()):
Expand Down
67 changes: 58 additions & 9 deletions tests/data/test_tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,73 @@
# limitations under the License.

import numpy as np
import os
import pathlib
import tensorflow as tf


from basic_pitch.data.tf_example_deserialization import sample_datasets, transcription_file_generator
from basic_pitch.data.tf_example_deserialization import transcription_dataset, transcription_file_generator


def test_prepare_dataset():
pass
def create_empty_tfrecord(filepath: pathlib.Path) -> None:
assert filepath.suffix == ".tfrecord"
with tf.io.TFRecordWriter(str(filepath)) as writer:
writer.write("")


def test_sample_datasets():
pass
# def test_prepare_dataset() -> None:
# pass


def test_transcription_file_generator(tmpdir: str):
print("FUCK YOU ")
file_gen, random_seed = transcription_file_generator("train", ["test2"], datasets_base_path=tmpdir, sample_weights=np.ndarray(1))
# def test_sample_datasets() -> None:
# pass


# def test_transcription_dataset(tmp_path: pathlib.Path) -> None:
# dataset_path = tmp_path / "test_ds" / "splits" / "train"
# dataset_path.mkdir(parents=True)
# create_empty_tfrecord(dataset_path / "test.tfrecord")

# file_gen, random_seed = transcription_file_generator(
# "train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
# )

# transcription_dataset(file_generator=file_gen, n_samples_per_track=1, random_seed=random_seed)


def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None:
dataset_path = tmp_path / "test_ds" / "splits" / "train"
dataset_path.mkdir(parents=True)
create_empty_tfrecord(dataset_path / "test.tfrecord")

file_gen, random_seed = transcription_file_generator(
"train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
)

assert random_seed is False

print(file_gen())
generator = file_gen()
assert next(generator).numpy().decode("utf-8") == str(dataset_path / "test.tfrecord")
try:
next(generator)
except Exception as e:
assert isinstance(e, StopIteration)


def test_transcription_file_generator_valid(tmp_path: pathlib.Path) -> None:
dataset_path = tmp_path / "test_ds" / "splits" / "valid"
dataset_path.mkdir(parents=True)
create_empty_tfrecord(dataset_path / "test.tfrecord")

file_gen, random_seed = transcription_file_generator(
"valid", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
)

assert random_seed is True

generator = file_gen()
assert next(generator).numpy().decode("utf-8") == str(dataset_path / "test.tfrecord")
try:
next(generator)
except Exception as e:
assert isinstance(e, StopIteration)

0 comments on commit 6e3710d

Please sign in to comment.