diff --git a/src/pmqd/torch.py b/src/pmqd/torch.py index 98d2ec6..526f6c9 100644 --- a/src/pmqd/torch.py +++ b/src/pmqd/torch.py @@ -9,7 +9,8 @@ import pandas as pd import torchaudio from torch.utils.data import Dataset -from torchaudio.datasets.utils import download_url, extract_archive +from torch.hub import download_url_to_file as download_url +from torchaudio.datasets.utils import _extract_tar from . import checksums @@ -56,10 +57,10 @@ def __init__( if download: if not self._audio_path.exists(): if not archive.is_file(): - download_url(url_audio, root, hash_value=CHECKSUM_AUDIO) - extract_archive(archive) + download_url(url_audio, str(root), hash_prefix=CHECKSUM_AUDIO) + _extract_tar(str(archive)) if not self._metadata_path.is_file(): - download_url(url_metadata, root, hash_value=CHECKSUM_METADATA) + download_url(url_metadata, str(root), hash_prefix=CHECKSUM_METADATA) if not self._metadata_path.is_file(): raise FileNotFoundError( diff --git a/tests/test_torch.py b/tests/test_torch.py index 6b77d7e..4890caf 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -17,7 +17,7 @@ def test_torchaudio_no_dowload(tmp_path: Path): def test_torchaudio_download(tmp_path: Path, dummy_data: Dict[str, Path]): - def mock_download_url(url: str, download_folder: str, hash_value: str) -> None: + def mock_download_url(url: str, download_folder: str, hash_prefix: str) -> None: filename = os.path.basename(url) shutil.copy(dummy_data[filename], download_folder)