Skip to content

Commit

Permalink
feat: add clotho dataset, audio web dataset doc
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 11, 2022
1 parent 58793ce commit 146146d
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 9 deletions.
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,33 @@ WAVDataset(
)
```


### AudioWebDataset
A [`WebDataset`](https://webdataset.github.io/webdataset/) extension for audio data. Assumes that the `.tar` file comes with pairs of `.wav` (or `.flac`) and `.json` data.
```py
from audio_data_pytorch import AudioWebDataset

dataset = AudioWebDataset(
urls='mywebdataset.tar'
)

waveform, info = next(iter(dataset))

print(waveform.shape) # torch.Size([2, 480000])
print(info.keys()) # dict_keys(['text'])
```

#### Full API:
```py
dataset = AudioWebDataset(
urls: Union[str, Sequence[str]],
transforms: Optional[Callable] = None, # Transforms to apply to audio files
batch_size: Optional[int] = None, # Why is batch_size here? See https://webdataset.github.io/webdataset/gettingstarted/#webdataset-and-dataloader
shuffle: int = 128, # Shuffle in groups of 128
**kwargs, # Forwarded to WebDataset class
)
```

### LJSpeech Dataset
An unsupervised dataset for LJSpeech with voice-only data.
```py
Expand Down Expand Up @@ -129,6 +156,31 @@ dataset = YoutubeDataset(
)
```

### Clotho Dataset
A wrapper for the [Clotho](https://zenodo.org/record/3490684#.Y0VVVOxBwR0) dataset extending `AudioWebDataset`. Requires `pip install py7zr` to decompress `.7z` archive.

```py
from audio_data_pytorch import ClothoDataset, Crop, Stereo, Mono

dataset = ClothoDataset(
root='./data/',
preprocess_sample_rate=48000, # Added to all files during preprocessing
preprocess_transforms=nn.Sequential(Crop(48000*10), Stereo()), # Added to all files during preprocessing
transforms=Mono() # Added dynamically at iteration time
)
```

```py
dataset = ClothoDataset(
root: str, # Path where the dataset is saved
split: str = 'train', # Dataset split, one of: 'train', 'valid'
preprocess_sample_rate: Optional[int] = None, # Preprocesses dataset to this sample rate
preprocess_transforms: Optional[Callable] = None, # Preprocesses dataset with the provided transforms
reset: bool = False, # Re-compute preprocessing if `True`
**kwargs # Forwarded to `AudioWebDataset`
)
```


## Transforms

Expand Down
1 change: 1 addition & 0 deletions audio_data_pytorch/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .audio_web_dataset import AudioWebDataset, AudioWebDatasetPreprocess
from .clotho_dataset import ClothoDataset
from .common_voice_dataset import CommonVoiceDataset
from .libri_speech_dataset import LibriSpeechDataset
from .lj_speech_dataset import LJSpeechDataset
Expand Down
65 changes: 57 additions & 8 deletions audio_data_pytorch/datasets/audio_web_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import re
import tarfile
from typing import Callable, List, Optional, Sequence, Union
from typing import Callable, Dict, List, Optional, Sequence, Union

import torchaudio
from torch import nn
Expand All @@ -18,6 +18,58 @@
"""


class AudioProcess:
    """Context manager that writes a processed ``.wav``/``.json`` pair next to
    a source audio file and deletes both files on exit.

    Used during dataset preprocessing: ``__enter__`` loads the audio,
    optionally resamples and transforms it, and saves it together with its
    metadata; the caller (e.g. a tar archiver) consumes the two files before
    ``__exit__`` removes them.
    """

    def __init__(
        self,
        path: str,
        info: Dict,
        sample_rate: Optional[int] = None,
        transforms: Optional[Callable] = None,
    ):
        # path: source audio file to process
        # info: metadata dict serialized to the companion .json file
        # sample_rate: if given, audio is resampled to this rate
        # transforms: optional callable applied to the waveform after resampling
        self.path = path
        self.sample_rate = sample_rate
        self.transforms = transforms
        self.info = info
        # Processed files share the source name plus a "_processed" suffix.
        self.path_prefix = f"{os.path.splitext(self.path)[0]}_processed"
        self.wav_dest_path = None
        self.json_dest_path = None

    def process_wav(self) -> str:
        """Load, optionally resample/transform, and save the waveform.

        Returns the path of the written ``.wav`` file.
        """
        waveform, rate = torchaudio.load(self.path)

        if exists(self.sample_rate):
            resample = Resample(source=rate, target=self.sample_rate)
            waveform = resample(waveform)
            rate = self.sample_rate

        if exists(self.transforms):
            waveform = self.transforms(waveform)

        wav_dest_path = f"{self.path_prefix}.wav"
        # (debug print removed — it spammed stdout once per processed file)
        torchaudio.save(wav_dest_path, waveform, rate)

        self.wav_dest_path = wav_dest_path
        return wav_dest_path

    def process_info(self) -> str:
        """Serialize the metadata dict to the companion ``.json`` file.

        Returns the path of the written ``.json`` file.
        """
        json_dest_path = f"{self.path_prefix}.json"
        with open(json_dest_path, "w") as f:
            json.dump(self.info, f)

        self.json_dest_path = json_dest_path
        return json_dest_path

    def __enter__(self):
        wav_processed_path = self.process_wav()
        json_processed_path = self.process_info()
        return wav_processed_path, json_processed_path

    def __exit__(self, *args):
        # Guard against a partially-failed __enter__: the original
        # unconditionally called os.remove, which raises TypeError on a None
        # path (and FileNotFoundError if the file is already gone).
        if self.wav_dest_path is not None and os.path.exists(self.wav_dest_path):
            os.remove(self.wav_dest_path)
        if self.json_dest_path is not None and os.path.exists(self.json_dest_path):
            os.remove(self.json_dest_path)


class AudioWebDatasetPreprocess:
def __init__(
self,
Expand Down Expand Up @@ -50,12 +102,12 @@ def str_to_tags(self, str: str) -> List[str]:

async def preprocess(self):
urls, path = self.urls, self.root
tarfile_name = os.path.join(path, f"{self.name}.tar")
tarfile_name = os.path.join(path, f"{self.name}.tar.gz")
waveform_id = 0

async with Downloader(urls, path=path) as files:
async with Decompressor(files, path=path) as folders:
with tarfile.open(tarfile_name, "w") as archive:
with tarfile.open(tarfile_name, "w:gz") as archive:
for folder in tqdm(folders):
for wav in tqdm(glob.glob(folder + "/**/*.wav")):
waveform, rate = torchaudio.load(wav)
Expand Down Expand Up @@ -112,16 +164,13 @@ class AudioWebDataset(WebDataset):

def __init__(
self,
path: Union[str, Sequence[str]],
urls: Union[str, Sequence[str]],
transforms: Optional[Callable] = None,
batch_size: Optional[int] = None,
recursive: bool = True,
shuffle: int = 128,
**kwargs,
):
paths = path if isinstance(path, (list, tuple)) else [path]
tars = get_all_tar_filenames(paths, recursive=recursive)
super().__init__(urls=tars, **kwargs)
super().__init__(urls=urls, **kwargs)

(
self.shuffle(shuffle)
Expand Down
77 changes: 77 additions & 0 deletions audio_data_pytorch/datasets/clotho_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import tarfile
from typing import Callable, List, Optional

import pandas as pd
from tqdm import tqdm

from ..utils import Decompressor, Downloader, camel_to_snake, run_async
from .audio_web_dataset import AudioProcess, AudioWebDataset


class ClothoDataset(AudioWebDataset):
    """Clotho audio-captioning dataset served as an `AudioWebDataset`.

    On first use (or when `reset=True`) the raw Zenodo archives are
    downloaded, decompressed, preprocessed and repacked into a single
    `.tar.gz` of numbered `.wav`/`.json` pairs; later runs reuse that
    archive. Decompressing the `.7z` audio archive requires `py7zr`.
    """

    def __init__(
        self,
        root: str,  # Directory under which the dataset folder is created
        split: str = "train",  # One of: 'train', 'valid'
        preprocess_sample_rate: Optional[int] = None,  # Resample rate applied once, during preprocessing
        preprocess_transforms: Optional[Callable] = None,  # Transforms applied once, during preprocessing
        reset: bool = False,  # Force re-preprocessing even if the tar exists
        **kwargs,  # Forwarded to `AudioWebDataset`
    ):
        self.root = root
        # Map the public split name onto Clotho's folder/file naming.
        self.split = self.split_conversion(split)
        self.preprocess_sample_rate = preprocess_sample_rate
        self.preprocess_transforms = preprocess_transforms

        # Build the tar archive only if it is missing or a rebuild is forced.
        if not os.path.exists(self.tar_file_name) or reset:
            run_async(self.preprocess())

        super().__init__(urls=self.tar_file_name, **kwargs)

    def split_conversion(self, split: str) -> str:
        """Translate 'train'/'valid' to Clotho's 'development'/'evaluation'.

        Raises KeyError for any other split name.
        """
        return {"train": "development", "valid": "evaluation"}[split]

    @property
    def urls(self) -> List[str]:
        """Zenodo download URLs: the audio `.7z` archive and the captions `.csv`."""
        # NOTE(review): the superclass is constructed with `urls=...`; if
        # `WebDataset` stores that in a `urls` attribute, this property may
        # shadow or conflict with it — verify against the base class.
        return [
            f"https://zenodo.org/record/4783391/files/clotho_audio_{self.split}.7z",
            f"https://zenodo.org/record/4783391/files/clotho_captions_{self.split}.csv",
        ]

    @property
    def data_path(self) -> str:
        """Dataset directory: `root` joined with the snake_case class name."""
        return os.path.join(self.root, camel_to_snake(self.__class__.__name__))

    @property
    def tar_file_name(self) -> str:
        """Path of the preprocessed archive for the chosen split."""
        return os.path.join(self.data_path, f"clotho_{self.split}.tar.gz")

    async def preprocess(self):
        """Download, decompress and repack the split into a `.tar.gz`.

        Each clip is processed (resampled/transformed) via `AudioProcess`
        and stored with a companion `.json` holding its captions under the
        key 'text', using zero-padded sequential archive names.
        """
        urls, path = self.urls, self.data_path
        waveform_id = 0

        async with Downloader(urls, path=path) as files:
            # The download yields one .7z (audio) and one .csv (captions).
            to_decompress = [f for f in files if f.endswith(".7z")]
            caption_csv_file = [f for f in files if f.endswith(".csv")][0]
            async with Decompressor(to_decompress, path=path) as folders:
                captions = pd.read_csv(caption_csv_file)
                length = len(captions.index)

                with tarfile.open(self.tar_file_name, "w:gz") as archive:
                    for i, caption in tqdm(captions.iterrows(), total=length):
                        wav_file_name = caption.file_name
                        wav_path = os.path.join(folders[0], self.split, wav_file_name)
                        # Clotho ships 5 captions per clip (caption_1..caption_5).
                        # The comprehension's `i` is scoped to it and does not
                        # clobber the enclosing loop index.
                        wav_captions = [caption[f"caption_{i}"] for i in range(1, 6)]
                        info = dict(text=wav_captions)

                        # AudioProcess writes temporary processed .wav/.json
                        # files; they are added to the archive and removed on
                        # context exit.
                        with AudioProcess(
                            path=wav_path,
                            sample_rate=self.preprocess_sample_rate,
                            transforms=self.preprocess_transforms,
                            info=info,
                        ) as (wav, json):
                            archive.add(wav, arcname=f"{waveform_id:06d}.wav")
                            archive.add(json, arcname=f"{waveform_id:06d}.json")

                        waveform_id += 1
10 changes: 10 additions & 0 deletions audio_data_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ def is_zip(file_name: str) -> bool:
return file_name.lower().endswith(".zip")


def is_7zip(file_name: str) -> bool:
    """Return True when *file_name* has a ``.7z`` extension (case-insensitive)."""
    lowered = file_name.lower()
    return lowered.endswith(".7z")


class Decompressor:
def __init__(
self,
Expand Down Expand Up @@ -192,6 +196,12 @@ def decompress(self, file_name: str):
elif is_tar(file_name):
with tarfile.open(file_name) as archive:
self.extract_all(archive, path)
elif is_7zip(file_name):
import py7zr

print(f"{self.description}: {path}")
with py7zr.SevenZipFile(file_name, mode="r") as archive:
archive.extractall(path=path)
else:
raise ValueError(f"Unsupported file extension: {file_name}")
return path
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-data-pytorch",
packages=find_packages(exclude=[]),
version="0.0.12",
version="0.0.13",
license="MIT",
description="Audio Data - PyTorch",
long_description_content_type="text/markdown",
Expand All @@ -18,6 +18,7 @@
"requests",
"tqdm",
"aiohttp",
"webdataset",
],
classifiers=[
"Development Status :: 4 - Beta",
Expand Down

0 comments on commit 146146d

Please sign in to comment.