diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml new file mode 100644 index 0000000..fb92e1c --- /dev/null +++ b/.github/workflows/build_and_test.yaml @@ -0,0 +1,60 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Build and Test + +on: + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +jobs: + build: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Create environment with miniconda + uses: conda-incubator/setup-miniconda@v3 + with: + miniforge-variant: Mambaforge + miniforge-version: latest + use-mamba: true + activate-environment: ras + + - name: Get Date + id: get-date + run: echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT + shell: bash + + - name: Cache Conda env + uses: actions/cache@v3 + with: + path: ${{ env.CONDA }}/envs + key: + conda-${{ runner.os }}--${{ runner.arch }}--${{ + steps.get-date.outputs.today }}-${{ + hashFiles('environment.yaml') }}-${{ env.CACHE_NUMBER + }} + env: + # Increase this value to reset the cache if environment.yaml has not changed + CACHE_NUMBER: 0 + id: cache + + - name: Update environment + run: mamba env update -n ras -f environment.yaml + if: steps.cache.outputs.cache-hit != 'true' + + - name: Run unit tests + run: | + pip install pytest + pytest diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml new file mode 100644 index 0000000..2db7f87 --- /dev/null +++ b/.github/workflows/static_checks.yaml @@ -0,0 +1,77 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Static code checks + +on: + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +env: + LICENSE: AGPL-3.0-or-later + FETCH_DEPTH: 1 + FULL_HISTORY: 0 + SKIP_WORD_PRESENCE_CHECK: 0 + +jobs: + static-code-check: + if: endsWith(github.event.repository.name, 'private') + + name: Run static code checks + # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu1804-Readme.md for a list of packages + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Setup history + if: github.ref == 'refs/heads/oss' + run: | + echo "FETCH_DEPTH=0" >> $GITHUB_ENV + echo "FULL_HISTORY=1" >> $GITHUB_ENV + + - name: Setup version + if: github.ref == 'refs/heads/melco' + run: | + echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV + + - name: Check out code + uses: actions/checkout@v4 + with: + fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history + + - name: Set up environment + run: git config user.email github-bot@merl.com + + - name: Set up python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + cache-dependency-path: 'requirements-dev.txt' + + - name: Install python packages + run: pip install -r requirements-dev.txt + + - name: Ensure lint and pre-commit steps have been run + uses: pre-commit/action@v3.0.1 + + - name: Check files + uses: merl-oss-private/merl-file-check-action@v1 + with: + license: ${{ env.LICENSE }} + full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above + skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} + + - name: Check license compatibility + if: github.ref != 'refs/heads/melco' + uses: merl-oss-private/merl_license_compatibility_checker@v1 + with: + input-filename: environment.yaml
license: ${{ env.LICENSE }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d2a45c9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,177 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# vscode +.vscode/ + +# Experiment +lightning_logs/ +log/ +exp*/ +evaluation/ +*.ckpt +*.png +*.wav +*.pdf diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b03c347 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,63 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# Pre-commit configuration. See https://pre-commit.com + +default_language_version: + python: python3.11 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + + - repo: https://gitlab.com/bmares/check-json5 + rev: v1.0.0 + hooks: + - id: check-json5 + + - repo: https://github.com/homebysix/pre-commit-macadmin + rev: v1.16.2 + hooks: + - id: check-git-config-email + args: ['--domains', 'merl.com'] + + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + args: + - --line-length=120 + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] + + # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python + # - repo: https://github.com/asottile/pyupgrade + # rev: v3.3.1 + # hooks: + # - id: pyupgrade + + # To stop flake8 errors from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings, + # so use verbose: true to see them. + - repo: https://github.com/pycqa/flake8 + rev: 7.1.0 + hooks: + - id: flake8 + # Black compatibility, Eradicate options + args: ["--max-line-length=119", "--extend-ignore=E203", "--eradicate-whitelist-extend", "eradicate:\\s*no", + "--exit-zero"] + verbose: true + additional_dependencies: [ + # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate + "flake8-eradicate" + ] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bf6e50a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,8 @@ + +# Contributing + +Sorry, but we do not currently accept contributions in the form of pull requests to this repository. However, you are welcome to post issues (bug reports, feature requests, questions, etc.). diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..cba6f6a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,660 @@ +### GNU AFFERO GENERAL PUBLIC LICENSE + +Version 3, 19 November 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. +<https://fsf.org/> + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +### Preamble + +The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains +free software for all its users. + +When we speak of free software, we are referring to freedom, not +price.
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + +A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + +The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + +An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing +under this license. + +The precise terms and conditions for copying, distribution and +modification follow. + +### TERMS AND CONDITIONS + +#### 0. Definitions. + +"This License" refers to version 3 of the GNU Affero General Public +License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. 
+ +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +#### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +#### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. 
+You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +#### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +#### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +#### 5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. +- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +#### 6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. 
In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +#### 7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +#### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. + +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +#### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +#### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +#### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +#### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +#### 13. Remote Network Interaction; Use with the GNU General Public License. 
+ +Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your +version supports such interaction) an opportunity to receive the +Corresponding Source of your version by providing access to the +Corresponding Source from a network server at no charge, through some +standard or customary means of facilitating copying of software. This +Corresponding Source shall include the Corresponding Source for any +work covered by version 3 of the GNU General Public License that is +incorporated pursuant to the following paragraph. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + +#### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU Affero General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever +published by the Free Software Foundation. + +If the Program specifies that a proxy can decide which future versions +of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +#### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +#### 16. Limitation of Liability. 
+ +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +#### 17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +END OF TERMS AND CONDITIONS + +### How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively state +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as + published by the Free Software Foundation, either version 3 of the + License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper +mail. + +If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for +the specific requirements. + +You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU AGPL, see <https://www.gnu.org/licenses/>. diff --git a/LICENSES/Apache-2.0.md b/LICENSES/Apache-2.0.md new file mode 100644 index 0000000..6afea95 --- /dev/null +++ b/LICENSES/Apache-2.0.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document.
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017 Johns Hopkins University (Shinji Watanabe) + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..bc2dbff --- /dev/null +++ b/README.md @@ -0,0 +1,104 @@ + +# Enhanced Reverberation as Supervision for Unsupervised Speech Separation + +This repository includes source code for training and evaluating the enhanced reverberation as supervision (ERAS) method proposed in the following Interspeech 2024 paper: + +``` +@InProceedings{Saijo2024_eras, + author = {Saijo, Kohei and Wichern, Gordon and Germain, Fran\c{c}ois G. and Pan, Zexu and {Le Roux}, Jonathan}, + title = {Enhanced Reverberation as Supervision for Unsupervised Speech Separation}, + booktitle = {Proc. Annual Conference of International Speech Communication Association (INTERSPEECH)}, + year = 2024, + month = sep +} +``` + +## Table of contents + +1. [Installation](#installation) +2. [How to run](#how-to-run) +3. [Contributing](#contributing) +4. [Copyright and license](#copyright-and-license) + +## Installation + +Clone this repo and create the conda environment: + +```sh +git clone https://github.com/merlresearch/reverberation-as-supervision +cd reverberation-as-supervision && conda env create -f environment.yaml +``` + +## How to run + +This repository supports training on the two datasets used in the paper, **WHAMR!** and **SMS-WSJ**. +Example training configuration files are under `./configs/*dataset-name*`. + +Before starting training, run the following command: + +```sh +conda activate ras +``` + +The main training script is `train.py`, which can be run as follows: + +```sh +python train.py --config /path/to/config --data_path /path/to/data +``` + +Here, `/path/to/data` is the directory containing the `wav8k` and `wav16k` directories for WHAMR!, or the directory containing `sms_wsj.json` for SMS-WSJ. + +As demonstrated in the paper, the best-performing model is obtained with two-stage training: first pre-train a model, then fine-tune it as follows (example commands for WHAMR!): + +```sh +# Train a model with an ISMS-loss weight of 0.3 for 20 epochs. +python train.py --config ./configs/whamr/eras_whamr_isms0.3_icc0.0.yaml --data_path /path/to/whamr + +# Fine-tune the pre-trained model without the ISMS loss and with the ICC loss for 80 epochs. +# Note that the pre-trained model's path must be specified in the YAML file. +python train.py --config ./configs/whamr/eras_whamr_isms0.0_icc0.1.yaml --data_path /path/to/whamr +``` + +Checkpoints and TensorBoard logs are saved under the `exp/eras/*config-name*` directory. +After training finishes, separation performance can be evaluated using `eval.py`: + +```sh +python eval.py --ckpt_path /path/to/.ckpt-file --data_path /path/to/data +``` + +The evaluation scores are logged to TensorBoard. + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions. + +## Copyright and license + +Released under the `AGPL-3.0-or-later` license, as found in the [LICENSE.md](LICENSE.md) file.
+ +All files, except as noted below: + +``` +Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) + +SPDX-License-Identifier: AGPL-3.0-or-later +``` + +The following file: + +- `nets/tfgridnetv2.py` + +was adapted from https://github.com/espnet/espnet (license included in [LICENSES/Apache-2.0.md](LICENSES/Apache-2.0.md)) + +``` +Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +Copyright (C) 2023 ESPnet Developers + +SPDX-License-Identifier: AGPL-3.0-or-later +SPDX-License-Identifier: Apache-2.0 +``` diff --git a/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml b/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml new file mode 100644 index 0000000..21e5872 --- /dev/null +++ b/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml @@ -0,0 +1,94 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path). +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 80 + limit_train_batches: 0.5 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1] + running_eras: true + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.0 + isms_loss_weight: 0.0 + icc_loss_weight: 0.1 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 3 + factor: 0.5 + +# fine-tuning config +pretrained_model_path: ./exp/eras/eras_smswsj_isms0.3_icc0.0/checkpoints/last.ckpt +warmup_steps: 4000 + +# General parameters +output_audio: true # save wav files of separated sources during testing in "audio_output" directory +log_file: "" diff --git a/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml b/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml new file mode 100644 index 0000000..2095b37 --- /dev/null +++ b/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path). 
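+# NOTE: with running_eras: true, the dataloader stacks both channel orders of
+# each mixture ([L, R] and [R, L]), which is why the effective batch size in
+# the forward pass is 2 * batch_size (see datasets/stft_dataset.py).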
+val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 20 + limit_train_batches: 0.5 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1] + running_eras: true + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.0 + isms_loss_weight: 0.3 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 3 + factor: 0.5 + +pretrained_model_path: null + + +# General parameters +log_file: "" diff --git a/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml b/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml new file mode 100644 index 0000000..53206a8 --- /dev/null +++ b/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 4 # mini-batch size +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 100 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 2, 4] + running_eras: false + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.166666 # 1/6 + isms_loss_weight: 0.06 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters 
+model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 2 + factor: 0.5 + +pretrained_model_path: null + + +# General parameters +log_file: "" diff --git a/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml b/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml new file mode 100644 index 0000000..3218faf --- /dev/null +++ b/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 4 # mini-batch size +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 100 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1, 2, 3, 4, 5] + running_eras: false + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.2 + isms_loss_weight: 0.02 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 2 + factor: 0.5 + +pretrained_model_path: null + + +# General parameters +log_file: "" diff --git a/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml b/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml new file mode 100644 index 0000000..c64524e --- /dev/null +++ b/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 4 # mini-batch size +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 100 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: 
null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 2, 4] + running_eras: false + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.166666 # 1/6 + isms_loss_weight: 0.06 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 3 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 2 + factor: 0.5 + +pretrained_model_path: null + + +# General parameters +log_file: "" diff --git a/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml b/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml new file mode 100644 index 0000000..5893655 --- /dev/null +++ b/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 4 # mini-batch size +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 100 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: smswsj +dataset_conf: + placeholder: null # just a placeholder + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1, 2, 3, 4, 5] + running_eras: false + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.2 + isms_loss_weight: 0.02 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 6 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 2 + factor: 0.5 + +pretrained_model_path: null + + +# General parameters +log_file: "" 
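These configs share one STFT definition between the loss and the model through YAML anchors (`&stft_conf`, `&fft_size`) and aliases (`*stft_conf`, `*fft_size`). A minimal sketch of how the anchors resolve at load time (assuming PyYAML, installed as `pyyaml` in `environment.yaml`; the snippet is illustrative, not one of the shipped configs):

```python
import yaml

# Abridged, hypothetical excerpt mirroring the anchor/alias structure of the
# configs above: fft_size is defined once and reused by the loss and the model.
snippet = """
stft_conf: &stft_conf
  fft_size: &fft_size 256
  hop_length: 64

eras_loss_conf:
  isms_loss_weight: 0.02
  stft_conf: *stft_conf

model_conf:
  fft_size: *fft_size
"""

conf = yaml.safe_load(snippet)
assert conf["model_conf"]["fft_size"] == conf["stft_conf"]["fft_size"] == 256
assert conf["eras_loss_conf"]["stft_conf"]["hop_length"] == 64
print(conf["eras_loss_conf"]["stft_conf"])  # {'fft_size': 256, 'hop_length': 64}
```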
diff --git a/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml b/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml new file mode 100644 index 0000000..c0ea0ec --- /dev/null +++ b/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml @@ -0,0 +1,96 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path). +val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 80 + limit_train_batches: 0.5 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: whamr +dataset_conf: + task: sep_reverb_all_srcs + use_min: true + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1] + running_eras: true + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.0 + isms_loss_weight: 0.0 + icc_loss_weight: 0.1 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 3 + factor: 0.5 + +# fine-tuning config +pretrained_model_path: ./exp/eras/eras_whamr_isms0.3_icc0.0/checkpoints/last.ckpt +warmup_steps: 4000 + + +# General parameters +output_audio: true # save wav files of separated sources during testing in "audio_output" directory +log_file: "" diff --git a/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml b/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml new file mode 100644 index 0000000..36f3b44 --- /dev/null +++ b/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml @@ -0,0 +1,94 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +# Basic training config +batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path). 
+val_batch_size: 4 # cv batch size +seed: 1128 # seed for initializing training +shuffle: true # shuffle training dataset during training +num_workers: 4 # number of workers in dataloaders, 0 for single thread + +# trainer args +trainer_conf: + max_epochs: 20 + limit_train_batches: 0.5 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + gradient_clip_val: 1.0 + +# early stopping configurations +early_stopping: null + +# checkpoint saving +model_checkpoint: + monitor: val/loss + save_top_k: 1 + mode: min + save_weights_only: false + save_last: true + +# train/dev dataset +dataset_name: whamr +dataset_conf: + task: sep_reverb_all_srcs + use_min: true + +# torch dataset +dataloading_conf: + sr: 8000 + chunk_size: 4 + chunking_strategy: random + ref_channel: &ref_channel 0 + channel_idx: [0, 1] + running_eras: true + normalization: true + +stft_conf: &stft_conf + fft_size: &fft_size 256 + window_length: *fft_size + hop_length: 64 + window_type: sqrt_hann + normalize: window + +eras_loss_conf: + loss_func: complex_l1 + past_taps: 19 + future_taps: 1 + ref_channel_loss_weight: 0.0 + isms_loss_weight: 0.3 + icc_loss_weight: 0.0 + ref_channel: *ref_channel + unsupervised: true + supervised_loss_type: after_filtering_ref_channel # used for validation + stft_conf: *stft_conf + +# Network parameters +model_name: tfgridnetv2 +model_conf: + fft_size: *fft_size + n_srcs: 2 + n_imics: 1 + n_layers: 4 + lstm_hidden_units: 256 + attn_n_head: 4 + attn_approx_qk_dim: 512 + emb_dim: 48 + emb_ks: 4 + emb_hs: 1 + eps: 1.0e-5 + +# Adam optimizer and reducelronplteau scheduler +optimizer_conf: + lr: 1.0e-3 +scheduler_conf: + patience: 3 + factor: 0.5 + +# fine-tuning config +pretrained_model_path: null + + +# General parameters +log_file: "" diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/datasets/dataset_creator.py b/datasets/dataset_creator.py new file mode 100644 index 0000000..75a7a3c --- /dev/null +++ b/datasets/dataset_creator.py @@ -0,0 +1,34 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from .paths.smswsj import get_smswsj_paths +from .paths.whamr import get_whamr_paths +from .stft_dataset import STFTDataset + + +def dataset_creator(hparams, data_path, partition): + path_list = datapath_creator(hparams.dataset_name, data_path, partition, hparams.dataset_conf) + + # some setups specific for training stage + is_training = partition == "tr" + + dataset = STFTDataset( + path_list, + is_training, + hparams.stft_conf, + **hparams.dataloading_conf, + ) + return dataset + + +def datapath_creator(dataset_name, data_path, partition, dataset_conf): + dataset_conf["partition"] = partition + if dataset_name == "whamr": + path_list = get_whamr_paths(data_path, **dataset_conf) + elif dataset_name == "smswsj": + path_list = get_smswsj_paths(data_path, **dataset_conf) + else: + raise ValueError("Dataset {} not currently supported.".format(dataset_name)) + return path_list diff --git a/datasets/paths/__init__.py b/datasets/paths/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datasets/paths/smswsj.py b/datasets/paths/smswsj.py new file mode 100644 index 0000000..f29eaf4 --- /dev/null +++ b/datasets/paths/smswsj.py @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Mitsubishi Electric Research 
Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import json + +PARTITION_MAP = { + "tr": {"name": "train_si284", "num_data": 33561}, + "cv": {"name": "cv_dev93", "num_data": 982}, + "tt": {"name": "test_eval92", "num_data": 1332}, +} + + +def get_smswsj_paths( + data_path, + partition="tt", + **kwargs, +): + # rename partition + stage = PARTITION_MAP[partition]["name"] + + # load metadata of smswsj + with open(data_path / "sms_wsj.json") as f: + mixinfo = json.load(f) + mixinfo = mixinfo["datasets"][stage] + + pathlist = [] + for key, info in mixinfo.items(): + # information of a sample + tmp = { + "id": key, + "mix": info["audio_path"]["observation"], + "srcs": { + "reverb1": info["audio_path"]["speech_image"][0], + "reverb2": info["audio_path"]["speech_image"][1], + "anechoic1": info["audio_path"]["speech_reverberation_early"][0], + "anechoic2": info["audio_path"]["speech_reverberation_early"][1], + "dry1": info["audio_path"]["speech_source"][0], + "dry2": info["audio_path"]["speech_source"][1], + }, + "num_samples": info["num_samples"]["speech_source"], + "offset": info["offset"], + } + # add to pathlist + pathlist.append(tmp) + # check number of data + assert len(pathlist) == PARTITION_MAP[partition]["num_data"] + + if partition == "cv": + # sort by length for efficient validation + # decreasing padded-zeros would lead to acccurate validation + pathlist = sorted(pathlist, key=lambda x: x["num_samples"]) + + return pathlist diff --git a/datasets/paths/whamr.py b/datasets/paths/whamr.py new file mode 100644 index 0000000..55fe2d0 --- /dev/null +++ b/datasets/paths/whamr.py @@ -0,0 +1,122 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import os +from collections import OrderedDict + + +def get_whamr_paths( + data_path, + task="sep_noisy", + partition="tt", + sr=8000, + use_min=True, +): + POSSIBLE_TASKS = [ + "sep_clean_reverb", + "denoise_both_reverb", + "sep_reverb_all_srcs", + "sep_noisy_reverb_all_srcs", + ] + + assert task in POSSIBLE_TASKS, task + S1_ANECHOIC_DIR = "s1_anechoic" + S2_ANECHOIC_DIR = "s2_anechoic" + S1_REVERB_DIR = "s1_reverb" + S2_REVERB_DIR = "s2_reverb" + S1_DRY = "s1" + S2_DRY = "s2" + BOTH_REVERB_DIR = "mix_both_reverb" + CLEAN_REVERB_DIR = "mix_clean_reverb" + + if task == "sep_reverb_all_srcs": + mix_dir = CLEAN_REVERB_DIR + src_dir_list = [ + S1_ANECHOIC_DIR, + S2_ANECHOIC_DIR, + S1_REVERB_DIR, + S2_REVERB_DIR, + S1_DRY, + S2_DRY, + ] + elif task == "sep_noisy_reverb_all_srcs": + mix_dir = BOTH_REVERB_DIR + src_dir_list = [ + S1_ANECHOIC_DIR, + S2_ANECHOIC_DIR, + S1_REVERB_DIR, + S2_REVERB_DIR, + S1_DRY, + S2_DRY, + ] + + else: + raise ValueError("WHAMR task {} not available, please choose from {}".format(task, POSSIBLE_TASKS)) + return get_wsj2mix_paths( + data_path, + partition=partition, + sr=sr, + use_min=use_min, + mix_dir=mix_dir, + src_dir_list=src_dir_list, + ) + + +def get_wsj2mix_paths( + data_path, + partition="tt", + sr=8000, + use_min=True, + mix_dir=None, + src_dir_list=None, +): + if mix_dir is None: + mix_dir = "mix" + if src_dir_list is None: + src_dir_list = ["s1", "s2"] + if sr == 8000: + wav_dir = "wav8k" + elif sr == 16000: + wav_dir = "wav16k" + else: + raise ValueError("set wsj0-2mix dataset sample rate to either 8kHz or 16kHz") + + if use_min: + max_or_min_dir = "min" + else: + max_or_min_dir = "max" + + root_path = os.path.join(data_path, wav_dir, max_or_min_dir, partition) + filelist = [f for f in 
os.listdir(os.path.join(root_path, src_dir_list[0])) if f.endswith(".wav")] + if partition == "tr": + assert len(filelist) == 20000, "Expected 20000 files in training set" + elif partition == "cv": + assert len(filelist) == 5000, "Expected 5000 files in validation set" + elif partition == "tt": + assert len(filelist) == 3000, "Expected 3000 files in testing set" + path_list = get_path_list(filelist, root_path, mix_dir, src_dir_list) + + return path_list + + +def get_path_dict(filename, root_dir, mix_dir, src_dir_list): + id_ = os.path.splitext(filename)[0] + path_dict = { + "id": id_, + "srcs": OrderedDict({s: os.path.join(root_dir, s, filename) for s in src_dir_list}), + "mix": os.path.join(root_dir, mix_dir, filename), + } + return path_dict + + +def get_path_list(filelist, root_dir, mix_dir, src_dir_list): + # if type(filelist) == type([]): + if isinstance(filelist, list): + path_list = [get_path_dict(f, root_dir, mix_dir, src_dir_list) for f in filelist] + if isinstance(filelist, dict): + path_list = {} + for key in filelist.keys(): + path_list[key] = [get_path_dict(f, root_dir, mix_dir, src_dir_list) for f in filelist[key]] + return path_list diff --git a/datasets/stft_dataset.py b/datasets/stft_dataset.py new file mode 100755 index 0000000..98edfed --- /dev/null +++ b/datasets/stft_dataset.py @@ -0,0 +1,279 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from typing import Dict, List, Union + +import numpy as np +import soundfile as sf +import torch + +from utils.audio_utils import do_stft, rotate_channels + + +class STFTDataset(torch.utils.data.Dataset): + def __init__( + self, + path_list: List, + training: bool, + stft_conf: Dict, + sr: int = 8000, + chunk_size: int = None, + chunking_strategy: str = "random", + ref_channel: int = 0, + channel_idx: Union[int, List[int]] = None, + running_eras: bool = False, + normalization: bool = False, + ): + """Dataset class. + + Parameters + ---------- + path_list: List[Dict] + List of the metadata of each audio data. + training: bool + Whether or not this dataset is initialized for training or not. + stft_conf: Dict + STFT configuration. + sr: int + Sampling rate. + chunk_size: int or float + Input length in seconds during training. + chunking_strategy: str + How to make a training chunk. + Choices are random, partial_overlap, or full_overlap. + NOTE: this is considered only for SMS-WSJ. + ref_channel: int + Reference channel index. + channel_idx: int or List[int] + Channel index (indices) of the input data. + To return multi-channel sources, specify the list. + running_eras: bool + Whether the training algorithm is ERAS or not. + ERAS needs a batch which includes both left and right channels in the same mini-batch + for the ICC loss (i.e., batch: [data1_L, data1_R, ..., dataN_L, dataN_R]). + normalization: bool + Whether to apply the variance normalization. 
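+        NOTE: "full_overlap" restricts training chunks to the segment where
+        both utterances are active, while "partial_overlap" guarantees at
+        least chunk_size // 4 samples of overlap between the two utterances
+        (see _read_audio and _random_start_and_end).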
+ """ + + self.training = training + self.running_eras = running_eras and self.training + if self.running_eras: + print( + f"Use {channel_idx}-th channel(s) as reference microphone(s)", + flush=True, + ) + + self.path_list = path_list + self.sr = sr + + self.chunk_size = int(chunk_size * sr) if self.training else None + + self.ref_channel = ref_channel + + self.stft_conf = stft_conf + self.channel_idx = channel_idx + self.normalization = normalization + + assert chunking_strategy in ["random", "partial_overlap", "full_overlap"] + self.chunking_stragegy = chunking_strategy + + def __len__(self): + return len(self.path_list) + + def __getitem__(self, index): + y_mix, y_srcs = self._read_audio(index) + y_mix = torch.from_numpy(y_mix) + if isinstance(y_srcs, dict): + y_srcs = {tag: torch.from_numpy(y_srcs[tag]) for tag in y_srcs.keys()} + elif y_srcs is not None: + y_srcs = torch.from_numpy(y_srcs) + mix_stft = self._stft(y_mix.T) # (frame, freq, n_chan) + + srcs = dict(y_mix=y_mix, y_srcs=y_srcs) + if self.running_eras: + mix_stft, srcs = self._rotate_and_stack_channels_stft(mix_stft, srcs) + srcs["y_mix_stft"] = mix_stft + return mix_stft, srcs + + def _read_audio(self, index, start_samp=0, end_samp=None): + path_dict = self.path_list[index % len(self.path_list)] + mix_path = path_dict["mix"] + n_frames = sf.info(mix_path).frames + + # chunking, for SMS-WSJ + if self.chunk_size is None or n_frames <= self.chunk_size: + if self.chunking_stragegy == "full_overlap": + offset = max(path_dict["offset"]) # start of shorter utterance + min_length = min(path_dict["num_samples"]) # length of shorter utterance + start_samp, end_samp = offset, min_length + offset + else: + start_samp = 0 + end_samp = None + else: # n_frames > self.frame_size: + # in some partially-overlapped data we need to find fully-overlapped segment + # now we assume number of sources is two + if self.chunking_stragegy == "full_overlap": + assert "num_samples" in path_dict and "offset" in path_dict + offset = max(path_dict["offset"]) # start of shorter utterance + min_length = min(path_dict["num_samples"]) # length of shorter utterance + + if min_length <= self.chunk_size: + start_samp, end_samp = offset, min_length + offset + else: + start_samp = np.random.randint(offset, min_length + offset - self.chunk_size) + end_samp = start_samp + self.chunk_size + + elif self.chunking_stragegy == "partial_overlap": + start_samp, end_samp = self._random_start_and_end(path_dict, min_overlap=self.chunk_size // 4) + + elif self.chunking_stragegy == "random": + start_samp = np.random.randint(0, n_frames - self.chunk_size) + end_samp = start_samp + self.chunk_size + + else: + raise NotImplementedError(self.chunking_stragegy) + + y_mix = self._read_wav(mix_path, start_samp, end_samp) + y_srcs = {"reverb": []} + for tag, src_path in path_dict["srcs"].items(): + if "reverb" in tag: + y_srcs["reverb"].append(self._read_wav(src_path, start_samp, end_samp)) + assert len(y_srcs["reverb"]) > 0, ("reverb", path_dict.keys()) + y_srcs["reverb"] = np.stack(y_srcs["reverb"], axis=-1) + + # normalization: + if self.normalization: + y_mix, y_srcs = self._normalization(y_mix, y_srcs) + + return y_mix, y_srcs + + def _read_wav(self, path, start_samp=0, end_samp=None): + y, sr = sf.read(path, start=start_samp, stop=end_samp, dtype=np.float32) + assert sr == self.sr, f"samplerate of data {sr} does not match requested samplerate {self.sr}" + assert y.ndim == 2 and y.shape[-1] > 1, f"audios must be multi-channel but the shape of {path} is {y.shape}" + + if 
self.channel_idx is not None: + return y[..., self.channel_idx] + else: + return y + + def _random_start_and_end(self, path_dict, min_overlap=None): + if min_overlap is None: + min_overlap = self.chunk_size // 5 + + assert "num_samples" in path_dict and "offset" in path_dict + offset = max(path_dict["offset"]) # start of shorter utterance + min_length = min(path_dict["num_samples"]) # length of shorter utterance + max_length = max(path_dict["num_samples"]) # mixture length + + # when mixture is shorter than the chunk size we want, + # we simply use the entire utterance + if max_length < self.chunk_size: + return 0, max_length + + # else, we randomly choose partially-overlapped segment + # where at least "min_overlap"-length overlap exists + left = offset + right = left + min_length + min_overlap = min(min_length // 2 - 1, min_overlap) + + assert left + min_overlap < right - min_overlap, (left, right, min_overlap) + if np.random.random() > 0.5: + start = np.random.randint(left + min_overlap, right - min_overlap) + end = start + self.chunk_size + if end > max_length: + start -= end - max_length + end = max_length + else: + end = np.random.randint(left + min_overlap, right - min_overlap) + start = end - self.chunk_size + if start < 0: + start = 0 + end = self.chunk_size + return start, end + + def _normalization(self, y_mix, y_srcs): + # variance normalization + mean = y_mix.mean(keepdims=True) + std = y_mix.std(keepdims=True) + + y_mix = (y_mix - mean) / std + if isinstance(y_srcs, dict): + y_srcs["reverb"] = (y_srcs["reverb"] - mean[None]) / std[None] + elif y_srcs is not None: + y_srcs = (y_srcs - mean[None]) / std[None] + + return y_mix, y_srcs + + def _stft(self, samples): + assert samples.ndim < 4 + # if there is channel dim + if samples.ndim == 3: + n_chan, n_src, n_samples = samples.shape + samples = samples.reshape(-1, n_samples) + reshaped = True + else: + reshaped = False + X = do_stft(samples, **self.stft_conf) + if reshaped: + X = X.reshape(X.shape[:2] + (n_chan, n_src)) + return X + + def _rotate_and_stack_channels_stft(self, feat, srcs): + """Rotate the channel order and stack. + Suppose the input `feat` is (n_frame, n_freq, n_chan) and n_chan==2. + This function stacks `feat` with the channel order of [0, 1] and [1, 0] + and returns (n_frame, n_freq, n_chan, n_chan), which is used in ERAS. + + Returned `feat` is stacked along the batch dim in the `collate_seq_eras` funciton. 
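+        For n_chan == 2, the stacked channel orders are [0, 1] and [1, 0]: the
+        returned feat[..., 0] uses channel 0 as the reference microphone and
+        feat[..., 1] uses channel 1 as the reference microphone.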
+ """ + + n_chan = len(self.channel_idx) # number of microphone channels + assert feat.shape[-1] == n_chan, feat.shape + + feat_stack = [] + for ref_channel in range(n_chan): + feat_tmp = rotate_channels(feat, ref_channel, channel_dim=-1) + feat_stack.append(feat_tmp) + feat_stack = torch.stack(feat_stack, dim=-1) + + # first initialize the dict to make code cleaner + srcs_stack = {} + for src in srcs: + if isinstance(srcs[src], dict) and src == "y_srcs": + srcs_stack[src] = {} + for key in srcs[src]: + srcs_stack[src][key] = [] + else: + srcs_stack[src] = [] + + for ref_channel in range(n_chan): + for src in srcs: + if isinstance(srcs[src], dict) and src == "y_srcs": + assert srcs[src]["reverb"].shape[-2] == n_chan, ( + src, + key, + srcs[src]["reverb"].shape, + ) + tmp = rotate_channels(srcs[src]["reverb"], ref_channel, channel_dim=-2) + srcs_stack[src]["reverb"].append(tmp) + + # process torch Tensor + elif srcs[src] is not None: + assert srcs[src].shape[-1] == n_chan, ( + src, + srcs[src].shape, + ) + tmp = rotate_channels(srcs[src], ref_channel, channel_dim=-1) + srcs_stack[src].append(tmp) + + # stack along batch dim + for src in srcs_stack: + if isinstance(srcs[src], dict) and src == "y_srcs": + srcs_stack[src]["reverb"] = torch.stack(srcs_stack[src]["reverb"], dim=-1) + elif srcs[src] is not None: + srcs_stack[src] = torch.stack(srcs_stack[src], dim=-1) + + return feat_stack, srcs_stack diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..2debfe1 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,25 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +name: ras +channels: + - pytorch + - conda-forge + - nvidia +dependencies: + - python=3.11.5 + - pytorch=2.1.1 + - pytorch-cuda=11.8 + - pip + - pip: + - pytorch-lightning==2.1.2 + - loguru + - pyyaml + - SoundFile==0.10.3.post1 + - tensorboard + - protobuf==3.20.3 + - fast-bss-eval + - pesq==0.0.4 + - numpy==1.26.4 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..0d05405 --- /dev/null +++ b/eval.py @@ -0,0 +1,50 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import argparse +from pathlib import Path + +import loguru +import torch +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.loggers import TensorBoardLogger + +from lightning_train import RASTrainingModule +from utils.config import yaml_to_parser + + +def main(args): + + config_path = args.ckpt_path.parent.parent / "hparams.yaml" + hparams = yaml_to_parser(config_path) + hparams = hparams.parse_args([]) + + seed_everything(0, workers=True) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = True + + exp_name, name, save_dir = [config_path.parents[i].name for i in range(3)] + logger = TensorBoardLogger(save_dir=save_dir, name=name, version=exp_name) + + trainer = Trainer( + logger=logger, + enable_progress_bar=True, + deterministic=True, + devices=1, + num_nodes=1, + ) + # testing + loguru.logger.info("Begin Testing") + model = RASTrainingModule.load_from_checkpoint(args.ckpt_path, hparams=hparams, data_path=args.data_path) + trainer.test(model) + loguru.logger.info("Testing complete") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=Path, required=True) + parser.add_argument("--data_path", type=Path, required=True) + args, other_options = 
parser.parse_known_args() + main(args) diff --git a/lightning_train.py b/lightning_train.py new file mode 100644 index 0000000..60cc4fb --- /dev/null +++ b/lightning_train.py @@ -0,0 +1,204 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import os +import random +from argparse import Namespace + +import fast_bss_eval +import numpy as np +import pytorch_lightning as pl +import torch +import torch.utils.data as data +from loguru import logger +from pesq import pesq +from pytorch_lightning.utilities import rank_zero_only +from torch import optim + +from datasets.dataset_creator import dataset_creator +from loss_functions.ras_loss import RASLoss +from nets.build_model import build_model +from utils.audio_utils import istft_4dim +from utils.collate import collate_seq, collate_seq_eras + + +class RASTrainingModule(pl.LightningModule): + def __init__(self, hparams, data_path): + super().__init__() + + if not isinstance(hparams, Namespace): + hparams = Namespace(hparams.model_name, **hparams.model_conf) + self.data_path = data_path + + self.save_hyperparameters(hparams) + self.model = build_model(hparams.model_name, hparams.model_conf) + self.loss = RASLoss(**hparams.eras_loss_conf) + + self.current_step = 0 # used for learning-rate warmup + + def load_pretrained_weight(self): + if self.hparams.pretrained_model_path is not None: + if torch.cuda.is_available(): + state_dict = torch.load(self.hparams.pretrained_model_path) + else: + state_dict = torch.load(self.hparams.pretrained_model_path, map_location=torch.device("cpu")) + try: + state_dict = state_dict["state_dict"] + except KeyError: + print("No key named state_dict. Directly loading from model.") + state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + self.model.load_state_dict(state_dict) + logger.info("Loaded weights from " + self.hparams.pretrained_model_path) + + def on_batch_end(self): + # learning rate warmup + self.warmup_lr() + + @rank_zero_only + def _symlink_logger(self): + # Keep track of which log file goes with which tensorboard log folder + tensorboard_log_dir = self.trainer.logger.log_dir + logger.info(f"Tensorboard logs: {tensorboard_log_dir}") + if os.path.exists(self.hparams.log_file): + _, log_name = os.path.split(self.hparams.log_file) + new_log_path = os.path.join(tensorboard_log_dir, log_name) + + # when resuming training, symlink already exists + if not os.path.islink(new_log_path): + os.symlink(os.path.abspath(self.hparams.log_file), new_log_path) + + def warmup_lr(self): + # get initial learning rate at step 0 + if self.current_step == 0: + for param_group in self.optimizers().optimizer.param_groups: + self.peak_lr = param_group["lr"] + + self.current_step += 1 + if getattr(self.hparams, "warmup_steps", 0) >= self.current_step: + for param_group in self.optimizers().optimizer.param_groups: + param_group["lr"] = self.peak_lr * self.current_step / self.hparams.warmup_steps + + def on_train_start(self): + self._symlink_logger() + self.load_pretrained_weight() + + def forward(self, x): + return self.model(x) + + def _step(self, batch): + input_features, target_dict = batch + input_features, lens = input_features + + y = self.forward(input_features) # (batch, frame, freq) -> (batch, frame, freq, num_src) + + loss = self.loss(y, target_dict, device=self.device, training=self.model.training) + return loss + + def training_step(self, batch, batch_idx): + loss = self._step(batch) + loss_for_logging = {} + for k, v in 
loss.items(): + loss_for_logging[f"train/{k}"] = v + self.log_dict(loss_for_logging, on_step=True, on_epoch=True, sync_dist=True) + + self.on_batch_end() + return loss["loss"] + + def validation_step(self, batch, batch_idx): + loss = self._step(batch) + loss_for_logging = {} + for k, v in loss.items(): + loss_for_logging[f"val/{k}"] = v + self.log_dict(loss_for_logging, on_epoch=True, sync_dist=True) + + return loss["loss"] + + def test_step(self, batch, batch_idx): + input_features, target_dict = batch + input_features, lens = input_features + sample_rate = self.hparams.dataloading_conf["sr"] + + est = self.forward(input_features) # (batch, frame, freq) -> (batch, frame, freq, num_src) + + # apply FCP + est = self.loss.filtering_func(est, input_features) + est = est[..., self.loss.ref_channel, :] + + # TF-domain -> time-domain by iSTFT + est = istft_4dim(est, **self.loss.stft_conf)[0].T + + # reference signal + ref = target_dict["y_srcs"]["reverb"][0][0, ..., self.loss.ref_channel, :].T + + # compute metrics + m = min(ref.shape[-1], est.shape[-1]) + sisnr, perm = fast_bss_eval.si_sdr(ref[..., :m], est[..., :m], return_perm=True) + sisnr = sisnr.mean().cpu().numpy() + perm = perm.cpu().numpy() + + sdr = fast_bss_eval.sdr(ref, est).mean().cpu().numpy() + + ref, est = ref.cpu().numpy(), est.cpu().numpy() + pesq_score = 0.0 + for i, p in enumerate(perm): + pesq_score += pesq(sample_rate, ref[i], est[p], mode="nb") + pesq_score /= i + 1 + + result = { + "test/sisnr": float(sisnr), + "test/sdr": float(sdr), + "test/pesq": float(pesq_score), + } + self.log_dict(result, on_epoch=True) + + return result + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), **self.hparams.optimizer_conf) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **self.hparams.scheduler_conf) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "val/loss", + } + + def _init_fn(self, worker_id): + random.seed(self.hparams.seed + worker_id) + np.random.seed(self.hparams.seed + worker_id) + torch.manual_seed(self.hparams.seed + worker_id) + + def _get_data_loader(self, partition): + shuffle = self.hparams.shuffle if partition == "tr" else None + if partition == "tr": + batch_size = self.hparams.batch_size + elif partition == "cv": + batch_size = self.hparams.val_batch_size + else: + batch_size = 1 + + d = dataset_creator(self.hparams, self.data_path, partition) + + if getattr(d, "running_eras", False): + collate_fn = collate_seq_eras + else: + collate_fn = collate_seq + + return data.DataLoader( + d, + batch_size, + collate_fn=collate_fn, + shuffle=shuffle, + num_workers=self.hparams.num_workers, + worker_init_fn=self._init_fn, + ) + + def train_dataloader(self): + return self._get_data_loader("tr") + + def val_dataloader(self): + return self._get_data_loader("cv") + + def test_dataloader(self): + return self._get_data_loader("tt") diff --git a/loss_functions/__init__.py b/loss_functions/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/loss_functions/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/loss_functions/complex.py b/loss_functions/complex.py new file mode 100644 index 0000000..b484916 --- /dev/null +++ b/loss_functions/complex.py @@ -0,0 +1,65 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from loss_functions.general import 
pit_from_pairwise_loss + + +def complex_l1(est, ref, permutation_search=False, reduction="mean_all", **kwargs): + """L1 loss on complex stft representations. + L1 loss on real part, imaginary part, and magnitude are computed. + + Parameters + ---------- + est: torch.Tensor, (n_batch, n_frame, n_freq, n_src) + Estimated (separated) signals. + ref: torch.Tensor, (n_batch, n_frame, n_freq, n_src) + Reference (ground-truth) signals. + permutation_search: bool + Whether to do permutation search between `est` and `ref`. + reduction: str + Argument for controlling the shape of returned tensor. + `keep_batchdim` returns a tensor with (n_batch, ), + `mean_all` returns a tensor with (), + and else, just return a tensor as it is. + + Returns + ---------- + loss: torch.Tensor, + Loss value. + """ + # sometimes ref is a tuple of a tensor and a length + if isinstance(ref, tuple): + ref = ref[0] + normalizer = abs(ref).sum(dim=(1, 2)) # (batch, n_src) + + # expand dimension for getting pairwise loss + if permutation_search: + est = est.unsqueeze(-1) + ref = ref.unsqueeze(-2) + normalizer = normalizer.unsqueeze(-2) + + # loss + real_l1 = abs(ref.real - est.real).sum(dim=(1, 2)) + imag_l1 = abs(ref.imag - est.imag).sum(dim=(1, 2)) + abs_l1 = abs(abs(ref) - abs(est)).sum(dim=(1, 2)) + loss = real_l1 + imag_l1 + abs_l1 + + loss = loss / normalizer + if permutation_search: + # compute loss with brute-force optimal permutation search + loss = pit_from_pairwise_loss(loss) + loss = loss / ref.shape[-1] + + if reduction is None: + return loss + elif reduction == "keep_batchdim": + if loss.ndim == 1: + return loss + else: + return loss.mean(dim=-1) + elif reduction == "mean_all": + return loss.mean() + else: + raise RuntimeError(f"Choose proper reduction: {reduction}") diff --git a/loss_functions/general.py b/loss_functions/general.py new file mode 100644 index 0000000..4ee9cab --- /dev/null +++ b/loss_functions/general.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import itertools + +import torch + + +def perms(num_sources): + return list(itertools.permutations(range(num_sources))) + + +def pit_from_pairwise_loss(pw_dist_mat, reduction="sum", return_perm=False): + n_batch, n_src, _ = pw_dist_mat.shape + perm_mat = pw_dist_mat.new_tensor(perms(n_src), dtype=torch.long) # [n_perm, n_src] + ind = pw_dist_mat.new_tensor(range(n_src), dtype=torch.long).unsqueeze(0) # [1, n_src] + expanded_perm_dist_mat = pw_dist_mat[:, ind, perm_mat] # [n_batch, n_perm, n_src] + perm_dist_mat = torch.sum(expanded_perm_dist_mat, dim=2) # [n_batch, n_perm] + min_loss_perm, min_inds = torch.min(perm_dist_mat, dim=1) # [n_batch] + + if return_perm: + opt_perm, perm_mat = torch.broadcast_tensors( + min_inds[:, None, None], perm_mat[None] + ) # [n_batch, n_perm, n_src] + opt_perm = torch.gather(perm_mat, 1, opt_perm[:, [0]])[:, 0] + + if reduction == "sum": + if return_perm: + return min_loss_perm, opt_perm + else: + return min_loss_perm + else: + # sometimes we want (batch x n_src) loss matrix + min_inds = min_inds[:, None, None].tile(1, 1, n_src) + min_loss = torch.gather(expanded_perm_dist_mat, 1, min_inds) + if return_perm: + return min_loss[:, 0, :], opt_perm # (n_batch, n_src) + else: + return min_loss diff --git a/loss_functions/ras_loss.py b/loss_functions/ras_loss.py new file mode 100644 index 0000000..784fab7 --- /dev/null +++ b/loss_functions/ras_loss.py @@ -0,0 +1,295 @@ +# Copyright (C) 2024 Mitsubishi Electric Research 
Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from functools import partial +from typing import Dict + +import torch +from fast_bss_eval import si_sdr + +from loss_functions.complex import complex_l1 +from utils.audio_utils import istft_4dim, stft_3dim +from utils.forward_convolutive_prediction import forward_convolutive_prediction + +loss_funcs = {"complex_l1": complex_l1} + + +class RASLoss(object): + """ + Reverberation as Supervision (RAS) loss [1, 2]. + This class also supports over-determined conditions (UNSSOR) [3]. + + + Parameters + ---------- + loss_func: str + Name of loss function. Only "complex_l1" is supported now. + future_taps: int + The number of future taps in FCP. + past_taps: int + The number of past taps in FCP. + ref_channel_loss_weight: float + Loss weight on the reference channel (input channel of separation model). + In ERAS, 0.0 is recommended. + isms_loss_weight: float + Loss weight of intra-source magnitude scattering (ISMS) loss. + icc_loss_weight: float + Loss weight of inter-channel consistency (ICC) loss. + ref_channel: int + Index of reference channel. + nonref_channel: int + Index of non-reference channel. Must be different from `ref_channel`. + unsupervised: bool + Either of unsupervised or supervised. + supervised_loss_type: str + How to compute supervised loss when doing (semi-)supervised learning. + before_filtering, after_filtering_ref_channel, or after_filtering_nonref_channel. + Supervised loss is computed when doing unsupervised learning on the validation set + to monitor the performance. + stft_conf: dict + Dictionary containing STFT parameters. + + References + ---------- + [1]: Rohith Aralikatti, Christoph Boeddeker, Gordon Wichern, Aswin Subramanian, and Jonathan Le Roux, + "Reverberation as Supervision for Speech Separation," Proc. ICASSP, 2023. + + [2]: Kohei Saijo, Gordon Wichern, François G. Germain, Zexu Pan, and Jonathan Le Roux, + "Enhanced Reverberation as Supervision for Unsupervised Speech Separation," Proc. Interspeech, 2024. + + [3]: Zhong-Qiu Wang and Shinji Watanabe, "UNSSOR: unsupervised neural speech separation + by leveraging over-determined training mixtures," Proc. NeurIPS, 2023. 
+ """ + + def __init__( + self, + loss_func: str = "complex_l1", + future_taps: int = 1, + past_taps: int = 19, + ref_channel_loss_weight: float = 0.0, + isms_loss_weight: float = 0.0, + icc_loss_weight: float = 0.0, + ref_channel: int = 0, + nonref_channel: int = 1, + unsupervised: bool = True, + supervised_loss_type: str = "after_filtering_ref_channel", + stft_conf: Dict = None, + ): + assert loss_func in loss_funcs, loss_func + + self.loss_func = loss_funcs[loss_func] + self.ref_channel = ref_channel + self.nonref_channel = nonref_channel + self.unsupervised = unsupervised + self.supervised_loss_type = supervised_loss_type + self.stft_conf = stft_conf + + # loss weights + self.ref_channel_loss_weight = ref_channel_loss_weight + self.isms_loss_weight = isms_loss_weight + self.icc_loss_weight = icc_loss_weight + + if self.icc_loss_weight > 0: + self.icc_loss = partial(inter_channel_consistency_loss, loss_func=self.loss_func) + + # define FCP used in the forward path + self.filtering_func = partial( + forward_convolutive_prediction, + past_taps=past_taps, + future_taps=future_taps, + ) + + def __call__(self, nn_outputs: torch.Tensor, targets: Dict, **kwargs): + # mixure signal used as supervision in RAS or UNSSOR + mix = targets["y_mix_stft"][0] + + n_batch, n_channels = mix.shape[0], mix.shape[-1] + unsup_loss = 0.0 + + est_src_filtered = self.filtering_func(nn_outputs, mix) # (n_batch, n_frames, n_freqs, n_chan, n_src) + est_mix = est_src_filtered.sum(dim=-1) + + # compute RAS loss + ras_loss = self.loss_func( + est_mix, + mix, + permutation_search=False, + reduction=None, + ) # loss: (n_batch, n_chan) + + assert ras_loss.shape == ( + n_batch, + n_channels, + ), f"loss must be (batch x channel) but {ras_loss.shape}" + + # compute ISMS loss + isms_loss = intra_source_magnitude_scattering_loss(est_src_filtered, mix, reduction=None) + assert ras_loss.shape == isms_loss.shape, "(loss must be (batch x channel)" + f"but {ras_loss.shape} and {isms_loss.shape})" + unsup_loss = ras_loss + self.isms_loss_weight * isms_loss + + # weighting loss on reference channel + unsup_loss[:, self.ref_channel] *= self.ref_channel_loss_weight + unsup_loss = unsup_loss.sum(dim=-1) + + # inter-source consistency loss + training = unsup_loss.requires_grad + + # Currently only one of L or R is loaded during validation + # and we cannot compute inter-source loss on dev set + if self.icc_loss_weight > 0 and training: + icc_loss = self.icc_loss(est_src_filtered) + unsup_loss += self.icc_loss_weight * icc_loss + + # in supervised case or on validation set, we compute the supervised loss + if training and self.unsupervised: + loss = unsup_loss.mean() + sup_loss = torch.zeros_like(unsup_loss) # for logging + else: + # pick some channels for SI-SDR evaluation + est_src_ref_channel = est_src_filtered[..., self.ref_channel, :] + est_src_nonref_channel = est_src_filtered[..., self.nonref_channel, :] + + targets = targets["y_srcs"] + + if self.supervised_loss_type == "before_filtering": + est = nn_outputs + tgt = targets["reverb"][0][..., self.ref_channel, :] + elif self.supervised_loss_type == "after_filtering_ref_channel": + est = est_src_ref_channel + tgt = targets["reverb"][0][..., self.ref_channel, :] + elif self.supervised_loss_type == "after_filtering_nonref_channel": + est = est_src_nonref_channel + tgt = targets["reverb"][0][..., self.nonref_channel, :] + + tgt = stft_3dim(tgt, **self.stft_conf) + + sup_loss = self.loss_func( + est[:, : tgt.shape[1]], + tgt, + permutation_search=True, + reduction="keep_batchdim", + ) 
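+        # Choose which loss drives backprop: the unsupervised RAS objective when
+        # self.unsupervised is set, otherwise the supervised loss; both averages
+        # are logged in `metrics` below.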
+ + uloss = unsup_loss.mean() + sloss = sup_loss.mean() + loss = uloss if self.unsupervised else sloss + + # metrics for logging and backprop + metrics = { + "loss": loss, + "ras_loss": ras_loss.mean(), + "isms_loss": isms_loss.mean(), + "unsup_loss": unsup_loss.mean(), + "sup_loss": sup_loss.mean(), + } + + # compute SI-SDR for logging + if not training: + # istft to get time-domain signals when using tf-domain loss + est_src_ref_channel = istft_4dim(est_src_ref_channel, **self.stft_conf).transpose(-1, -2) + est_src_nonref_channel = istft_4dim(est_src_nonref_channel, **self.stft_conf).transpose(-1, -2) + + ref = targets["reverb"][0].transpose(-1, 1) + m = min(est_src_ref_channel.shape[-1], ref.shape[-1]) + metrics["sisnr_ref_channel"] = si_sdr(ref[..., self.ref_channel, :m], est_src_ref_channel[..., :m]).mean() + metrics["sisnr_nonref_channel"] = si_sdr( + ref[..., self.nonref_channel, :m], est_src_nonref_channel[..., :m] + ).mean() + + return metrics + + +def intra_source_magnitude_scattering_loss(est, mix, reduction=None, eps=1e-8): + """Intra source magnitude scattering loss (ISMS) proposed in [3] (Eq.10). + + Parameters + ---------- + est: torch.Tensor, (..., n_frame, n_freq, n_chan, n_src) + Separation estimates AFTER applying FCP + mix: torch.Tensor, (..., n_frame, n_freq, n_chan) + Mixture observed at the microphone + reduction: str + How to aggregate the loss values. + Must be chosen from {None, "keep_batchdim", "mean_all"} + + Returns + ---------- + loss: torch.Tensor + ISMS loss value. Shape depends on the specified reduction. + """ + mix_logmag = torch.log(abs(mix) + eps) + est_logmag = torch.log(abs(est) + eps) + + mix_var = torch.var(mix_logmag, dim=-2).sum(dim=-2) + est_var = torch.var(est_logmag, dim=-3).mean(dim=-1).sum(dim=-2) + + loss = est_var / mix_var # (batch, n_chan) + + if reduction is None: + return loss + elif reduction == "keep_batchdim": + return loss.mean(dim=-1) + elif reduction == "mean_all": + return loss.mean() + else: + raise RuntimeError(f"Choose proper reduction: {reduction}") + + +def inter_channel_consistency_loss(est: torch.Tensor, loss_func: callable): + """Inter-channel consistency (ICC) loss proposed in [2]. + + In ERAS, each component in the mini-batch `est` is: + est[0]: separated signals from mix1 at 1st channel and mapped to [1st, 2nd] channels + est[1]: separated signals from mix1 at 2nd channel and mapped to [2nd, 1st] channels + est[2]: separated signals from mix2 at 1st channel and mapped to [1st, 2nd] channels + est[3]: separated signals from mix2 at 2nd channel and mapped to [2nd, 1st] channels + ... + est[B-2]: separated signals from mix{B/2} at 1st channel and mapped to [1st, 2nd] channels + est[B-1]: separated signals from mix{B/2} at 2nd channel and mapped to [2nd, 1st] channels + + Note that mix{1,...,B/2} are different (totally unrelated) mixtures in the same mini-batch. + + The ICC loss uses signals mapped to the reference microphone as the pseudo-targets of + those mapped to the non-reference microphone, e.g., + - est[0][0] is the pseudo-target of est[1][1] + - est[1][0] is the pseudo-target of est[0][1] + + + Parameters + ---------- + est: torch.Tensor + Separation outputs after FCP (n_batch, n_frames, n_freqs, n_chan, n_src) + loss_func: callable + Loss function + + Returns + ---------- + loss: torch.Tensor, + The inter-channel consistency loss value. 
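+    For example, with n_batch == 4, the permutation used below is p = [1, 0, 3, 2]:
+    the signals of est[1] mapped to the non-reference microphone are trained to
+    match the (detached) reference-microphone signals of est[0], and vice versa;
+    est[2] and est[3] are paired in the same way.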
+ """ + n_batch, n_chan = est.shape[0], est.shape[-2] + assert n_chan == 2 + + # the first channel in `n_chan` dimension is the one mapped to the reference microphone + # while the second is mapped to the non-reference microphone. + ref = est[..., 0, :].clone().detach() # signals mapped to reference microphone + est = est[..., 1, :] # signals mapped to non-reference microphone + + # change the batch order of `est` to [1, 0, 3, 2, ...] + # to align the batch order of ref and est + p_tmp = torch.arange(n_batch) + p = p_tmp.clone() + p[0::2], p[1::2] = p_tmp[1::2], p_tmp[0::2] # p: [1, 0, 3, 2, ...] + est = est[p] + + # compute loss + loss = loss_func( + est, + ref, + permutation_search=True, + reduction=None, + ) + return loss.mean() diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..8fccc03 --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/nets/build_model.py b/nets/build_model.py new file mode 100644 index 0000000..36f7c71 --- /dev/null +++ b/nets/build_model.py @@ -0,0 +1,15 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from .tfgridnetv2 import TFGridNetV2 + + +def build_model(model_name, model_conf): + if model_name == "tfgridnetv2": + model = TFGridNetV2(**model_conf) + else: + raise ValueError("Model type {} not currently supported.".format(model_name)) + + return model diff --git a/nets/tfgridnetv2.py b/nets/tfgridnetv2.py new file mode 100644 index 0000000..576ba27 --- /dev/null +++ b/nets/tfgridnetv2.py @@ -0,0 +1,351 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2023 ESPnet Developers +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter + +"""This script is adapted from ESPnet (https://github.com/espnet/espnet). +Part of the code is modified for our use. +https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py +""" + + +class TFGridNetV2(nn.Module): + """Offline TFGridNetV2. Compared with TFGridNet, TFGridNetV2 speeds up the code + by vectorizing multiple heads in self-attention, and better dealing with + Deconv1D in each intra- and inter-block when emb_ks == emb_hs. + + Reference: + [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, + "TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation", + in TASLP, 2023. + [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, + "TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural + Speaker Separation", in ICASSP, 2023. + + Args: + fft_size: fft size. + n_imics: number of microphone channels (only fixed-array geometry supported). + n_srcs: number of output sources/speakers. + n_layers: number of TFGridNetV2 blocks. + lstm_hidden_units: number of hidden units in LSTM. + attn_n_head: number of heads in self-attention + attn_approx_qk_dim: approximate dimention of frame-level key and value tensors + emb_dim: embedding dimension + emb_ks: kernel size for unfolding and deconv1D + emb_hs: hop size for unfolding and deconv1D + eps: small epsilon for normalization layers. 
+ """ + + def __init__( + self, + fft_size=256, + n_imics=1, + n_srcs=2, + n_layers=4, + lstm_hidden_units=192, + attn_n_head=4, + attn_approx_qk_dim=512, + emb_dim=48, + emb_ks=4, + emb_hs=1, + eps=1.0e-5, + ): + super().__init__() + self.n_srcs = n_srcs + self.n_layers = n_layers + self.n_imics = n_imics + + n_freqs = fft_size // 2 + 1 + + t_ksize = 3 + ks, padding = (t_ksize, 3), (t_ksize // 2, 1) + self.conv = nn.Sequential( + nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding), + nn.GroupNorm(1, emb_dim, eps=eps), + ) + + self.blocks = nn.ModuleList([]) + for _ in range(n_layers): + self.blocks.append( + GridNetV2Block( + emb_dim, + emb_ks, + emb_hs, + n_freqs, + lstm_hidden_units, + n_head=attn_n_head, + approx_qk_dim=attn_approx_qk_dim, + eps=eps, + ) + ) + + self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, ks, padding=padding) + + def forward( + self, + input: torch.Tensor, + ) -> torch.Tensor: + """Forward. + + Args: + input (torch.Tensor): batched multi-channel audio tensor with + M audio channels in TF-domain [B, T, F, M] + + Returns: + batch (torch.Tensor): batched monaural audio tensor with + N separated signals in TF-domain [B, T, F, N] + """ + + # using only specified number of channels + batch0 = input[..., : self.n_imics] + + batch = torch.movedim(batch0, 3, 1) # [B, M, T, F] + batch = torch.cat((batch.real, batch.imag), dim=1) # [B, 2*M, T, F] + n_batch, _, n_frames, n_freqs = batch.shape + + batch = self.conv(batch) # [B, -1, T, F] + + for ii in range(self.n_layers): + batch = self.blocks[ii](batch) # [B, -1, T, F] + + batch = self.deconv(batch) # [B, n_srcs*2, T, F] + + batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs]) + batch = torch.complex(batch[:, :, 0], batch[:, :, 1]) + return torch.movedim(batch, 1, 3) + + @property + def num_spk(self): + return self.n_srcs + + @staticmethod + def pad2(input_tensor, target_len): + input_tensor = torch.nn.functional.pad(input_tensor, (0, target_len - input_tensor.shape[-1])) + return input_tensor + + +class GridNetV2Block(nn.Module): + def __getitem__(self, key): + return getattr(self, key) + + def __init__( + self, + emb_dim, + emb_ks, + emb_hs, + n_freqs, + hidden_channels, + n_head=4, + approx_qk_dim=512, + eps=1e-5, + ): + super().__init__() + + in_channels = emb_dim * emb_ks + + self.intra_norm = nn.LayerNorm(emb_dim, eps=eps) + self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True) + if emb_ks == emb_hs: + self.intra_linear = nn.Linear(hidden_channels * 2, in_channels) + else: + self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs) + + self.inter_norm = nn.LayerNorm(emb_dim, eps=eps) + self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True) + if emb_ks == emb_hs: + self.inter_linear = nn.Linear(hidden_channels * 2, in_channels) + else: + self.inter_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs) + + E = math.ceil(approx_qk_dim * 1.0 / n_freqs) # approx_qk_dim is only approximate + assert emb_dim % n_head == 0 + + self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1)) + self.add_module( + "attn_norm_Q", + AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps), + ) + + self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1)) + self.add_module( + "attn_norm_K", + AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps), + ) + + self.add_module("attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)) + 
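+        # NOTE: V projects to n_head * (emb_dim // n_head) channels, so
+        # concatenating the heads after attention restores exactly emb_dim
+        # channels (emb_dim % n_head == 0 is asserted above).
+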
self.add_module( + "attn_norm_V", + AllHeadPReLULayerNormalization4DCF((n_head, emb_dim // n_head, n_freqs), eps=eps), + ) + + self.add_module( + "attn_concat_proj", + nn.Sequential( + nn.Conv2d(emb_dim, emb_dim, 1), + nn.PReLU(), + LayerNormalization4DCF((emb_dim, n_freqs), eps=eps), + ), + ) + + self.emb_dim = emb_dim + self.emb_ks = emb_ks + self.emb_hs = emb_hs + self.n_head = n_head + + def forward(self, x): + """GridNetV2Block Forward. + + Args: + x: [B, C, T, Q] + out: [B, C, T, Q] + """ + B, C, old_T, old_Q = x.shape + + olp = self.emb_ks - self.emb_hs + T = math.ceil((old_T + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks + Q = math.ceil((old_Q + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks + + x = x.permute(0, 2, 3, 1) # [B, old_T, old_Q, C] + x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp)) # [B, T, Q, C] + + # intra RNN + input_ = x + intra_rnn = self.intra_norm(input_) # [B, T, Q, C] + if self.emb_ks == self.emb_hs: + intra_rnn = intra_rnn.view([B * T, -1, self.emb_ks * C]) # [BT, Q//I, I*C] + intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, Q//I, H] + intra_rnn = self.intra_linear(intra_rnn) # [BT, Q//I, I*C] + intra_rnn = intra_rnn.view([B, T, Q, C]) + else: + intra_rnn = intra_rnn.view([B * T, Q, C]) # [BT, Q, C] + intra_rnn = intra_rnn.transpose(1, 2) # [BT, C, Q] + intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BT, C*I, -1] + intra_rnn = intra_rnn.transpose(1, 2) # [BT, -1, C*I] + + intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, -1, H] + + intra_rnn = intra_rnn.transpose(1, 2) # [BT, H, -1] + intra_rnn = self.intra_linear(intra_rnn) # [BT, C, Q] + intra_rnn = intra_rnn.view([B, T, C, Q]) + intra_rnn = intra_rnn.transpose(-2, -1) # [B, T, Q, C] + intra_rnn = intra_rnn + input_ # [B, T, Q, C] + + intra_rnn = intra_rnn.transpose(1, 2) # [B, Q, T, C] + + # inter RNN + input_ = intra_rnn + inter_rnn = self.inter_norm(input_) # [B, Q, T, C] + if self.emb_ks == self.emb_hs: + inter_rnn = inter_rnn.view([B * Q, -1, self.emb_ks * C]) # [BQ, T//I, I*C] + inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, T//I, H] + inter_rnn = self.inter_linear(inter_rnn) # [BQ, T//I, I*C] + inter_rnn = inter_rnn.view([B, Q, T, C]) + else: + inter_rnn = inter_rnn.view(B * Q, T, C) # [BQ, T, C] + inter_rnn = inter_rnn.transpose(1, 2) # [BQ, C, T] + inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BQ, C*I, -1] + inter_rnn = inter_rnn.transpose(1, 2) # [BQ, -1, C*I] + + inter_rnn, _ = self.inter_rnn(inter_rnn) # [BQ, -1, H] + + inter_rnn = inter_rnn.transpose(1, 2) # [BQ, H, -1] + inter_rnn = self.inter_linear(inter_rnn) # [BQ, C, T] + inter_rnn = inter_rnn.view([B, Q, C, T]) + inter_rnn = inter_rnn.transpose(-2, -1) # [B, Q, T, C] + inter_rnn = inter_rnn + input_ # [B, Q, T, C] + + inter_rnn = inter_rnn.permute(0, 3, 2, 1) # [B, C, T, Q] + + inter_rnn = inter_rnn[..., olp : olp + old_T, olp : olp + old_Q] + batch = inter_rnn + + Q = self["attn_norm_Q"](self["attn_conv_Q"](batch)) # [B, n_head, C, T, Q] + K = self["attn_norm_K"](self["attn_conv_K"](batch)) # [B, n_head, C, T, Q] + V = self["attn_norm_V"](self["attn_conv_V"](batch)) # [B, n_head, C, T, Q] + Q = Q.view(-1, *Q.shape[2:]) # [B*n_head, C, T, Q] + K = K.view(-1, *K.shape[2:]) # [B*n_head, C, T, Q] + V = V.view(-1, *V.shape[2:]) # [B*n_head, C, T, Q] + + Q = Q.transpose(1, 2) + Q = Q.flatten(start_dim=2) # [B', T, C*Q] + + K = K.transpose(2, 3) + K = K.contiguous().view([B * self.n_head, -1, old_T]) # [B', 
C*Q, T]
+
+        V = V.transpose(1, 2)  # [B', T, C, Q]
+        old_shape = V.shape
+        V = V.flatten(start_dim=2)  # [B', T, C*Q]
+        emb_dim = Q.shape[-1]
+
+        attn_mat = torch.matmul(Q, K) / (emb_dim**0.5)  # [B', T, T]
+        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]
+        V = torch.matmul(attn_mat, V)  # [B', T, C*Q]
+
+        V = V.reshape(old_shape)  # [B', T, C, Q]
+        V = V.transpose(1, 2)  # [B', C, T, Q]
+        emb_dim = V.shape[1]
+
+        batch = V.contiguous().view([B, self.n_head * emb_dim, old_T, old_Q])  # [B, C, T, Q]
+        batch = self["attn_concat_proj"](batch)  # [B, C, T, Q]
+
+        out = batch + inter_rnn
+        return out
+
+
+class LayerNormalization4DCF(nn.Module):
+    def __init__(self, input_dimension, eps=1e-5):
+        super().__init__()
+        assert len(input_dimension) == 2
+        param_size = [1, input_dimension[0], 1, input_dimension[1]]
+        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        init.ones_(self.gamma)
+        init.zeros_(self.beta)
+        self.eps = eps
+
+    def forward(self, x):
+        if x.ndim == 4:
+            stat_dim = (1, 3)
+        else:
+            raise ValueError("Expected x to have 4 dimensions, but got {}".format(x.ndim))
+        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]
+        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,1,T,1]
+        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
+        return x_hat
+
+
+class AllHeadPReLULayerNormalization4DCF(nn.Module):
+    def __init__(self, input_dimension, eps=1e-5):
+        super().__init__()
+        assert len(input_dimension) == 3
+        H, E, n_freqs = input_dimension
+        param_size = [1, H, E, 1, n_freqs]
+        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        init.ones_(self.gamma)
+        init.zeros_(self.beta)
+        self.act = nn.PReLU(num_parameters=H, init=0.25)
+        self.eps = eps
+        self.H = H
+        self.E = E
+        self.n_freqs = n_freqs
+
+    def forward(self, x):
+        assert x.ndim == 4
+        B, _, T, _ = x.shape
+        x = x.view([B, self.H, self.E, T, self.n_freqs])
+        x = self.act(x)  # [B,H,E,T,F]
+        stat_dim = (2, 4)
+        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,H,1,T,1]
+        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,H,1,T,1]
+        x = ((x - mu_) / std_) * self.gamma + self.beta  # [B,H,E,T,F]
+        return x
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..c8b4a2c
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,8 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+pre-commit
+black>=22
+flake8
+pytest
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/tests/test_fcp.py b/tests/test_fcp.py
new file mode 100644
index 0000000..6c5e157
--- /dev/null
+++ b/tests/test_fcp.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import pytest
+import torch
+
+from utils.forward_convolutive_prediction import forward_convolutive_prediction as fcp
+from utils.forward_convolutive_prediction import stack_past_and_future_taps
+
+
+@pytest.mark.parametrize("past_tap", [1, 5, 19])
+@pytest.mark.parametrize("future_tap", [0, 1])
+def test_stack_past_and_future_taps_forward(past_tap,
future_tap): + n_batch, n_frame, n_freq, n_src = 1, 50, 65, 2 + input = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64) + + padded = stack_past_and_future_taps(input, past_tap, future_tap) + assert padded.shape == (n_batch, n_frame, past_tap + future_tap + 1, n_freq, n_src) + + +@pytest.mark.parametrize("past_tap", [1, 5, 19]) +@pytest.mark.parametrize("future_tap", [0, 1]) +@pytest.mark.parametrize("n_chan", [1, 2, 6]) +@pytest.mark.parametrize("n_src", [1, 2]) +def test_fcp_forward(past_tap, future_tap, n_chan, n_src): + n_batch, n_frame, n_freq = 1, 50, 65 + est = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64) + mix = torch.randn((n_batch, n_frame, n_freq, n_chan), dtype=torch.complex64) + + est_filtered = fcp(est, mix, past_tap, future_tap) + assert est_filtered.shape == (n_batch, n_frame, n_freq, n_chan, n_src) diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..36db6a3 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import pytest +import torch + +from loss_functions.complex import complex_l1 +from loss_functions.ras_loss import inter_channel_consistency_loss as icc_loss +from loss_functions.ras_loss import intra_source_magnitude_scattering_loss as isms_loss + + +@pytest.mark.parametrize("permutation_search", [True, False]) +@pytest.mark.parametrize("reduction", ["mean_all", "keep_batchdim", None]) +def test_complex_l1(permutation_search, reduction): + n_batch, n_frame, n_freq, n_src = 1, 50, 65, 2 + est = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64) + ref = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64) + + loss = complex_l1(est, ref, permutation_search=permutation_search, reduction=reduction) + + if reduction is None: + if permutation_search: + assert loss.shape == (n_batch,) + else: + assert loss.shape == (n_batch, n_src) + elif reduction == "keep_batchdim": + assert loss.shape == (n_batch,) + else: + assert loss.shape == torch.Size([]) + + +@pytest.mark.parametrize("reduction", ["mean_all", "keep_batchdim", None]) +def test_isms_loss(reduction): + n_batch, n_frame, n_freq, n_chan, n_src = 1, 50, 65, 2, 2 + est = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64) + mix = torch.randn((n_batch, n_frame, n_freq, n_chan), dtype=torch.complex64) + + loss = isms_loss(est, mix, reduction=reduction) + + if reduction is None: + assert loss.shape == (n_batch, n_chan) + elif reduction == "keep_batchdim": + assert loss.shape == (n_batch,) + else: + assert loss.shape == torch.Size([]) + + +def test_icc_loss(): + n_batch, n_frame, n_freq, n_chan, n_src = 1, 50, 65, 2, 2 + est_ch1 = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64) + est_ch2 = torch.stack((est_ch1[..., 1, :], est_ch1[..., 0, :]), dim=-2) + est = torch.cat((est_ch1, est_ch2), dim=0) + + loss = icc_loss(est, complex_l1) + + assert loss.shape == torch.Size([]) + assert loss == 0.0 + + +@pytest.mark.parametrize("n_chan", [1, 3, 6]) +def test_icc_loss_invalid_type(n_chan): + # n_chan != 2 raises the assertion error + + n_batch, n_frame, n_freq, n_src = 2, 50, 65, 2 + est = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64) + + with pytest.raises(AssertionError): + icc_loss(est, complex_l1) diff --git a/tests/test_tfgridnet.py b/tests/test_tfgridnet.py new file mode 100644 index 0000000..9fb1c95 
--- /dev/null +++ b/tests/test_tfgridnet.py @@ -0,0 +1,60 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import pytest +import torch + +from nets.tfgridnetv2 import TFGridNetV2 + + +@pytest.mark.parametrize("fft_size", [128, 256]) +@pytest.mark.parametrize("n_srcs", [1, 2]) +@pytest.mark.parametrize("n_imics", [1, 3, 6]) +@pytest.mark.parametrize("n_layers", [1, 4, 6]) +@pytest.mark.parametrize("lstm_hidden_units", [16]) +@pytest.mark.parametrize("attn_n_head", [1, 4]) +@pytest.mark.parametrize("attn_approx_qk_dim", [32]) +@pytest.mark.parametrize("emb_dim", [16]) +@pytest.mark.parametrize("emb_ks", [4]) +@pytest.mark.parametrize("emb_hs", [1, 4]) +@pytest.mark.parametrize("eps", [1.0e-5]) +def test_tfgridnetv2_forward_backward( + fft_size, + n_srcs, + n_imics, + n_layers, + lstm_hidden_units, + attn_n_head, + attn_approx_qk_dim, + emb_dim, + emb_ks, + emb_hs, + eps, +): + + model = TFGridNetV2( + fft_size=fft_size, + n_imics=n_imics, + n_srcs=n_srcs, + n_layers=n_layers, + lstm_hidden_units=lstm_hidden_units, + attn_n_head=attn_n_head, + attn_approx_qk_dim=attn_approx_qk_dim, + emb_dim=emb_dim, + emb_ks=emb_ks, + emb_hs=emb_hs, + eps=eps, + ) + model.train() + + n_freqs = fft_size // 2 + 1 + n_batch, n_frames = 2, 18 + real = torch.rand(n_batch, n_frames, n_freqs, n_imics) + imag = torch.rand(n_batch, n_frames, n_freqs, n_imics) + x = torch.complex(real, imag) + + output = model(x) + assert output.shape == (n_batch, n_frames, n_freqs, n_srcs) + sum(output).abs().mean().backward() diff --git a/train.py b/train.py new file mode 100644 index 0000000..5f75d73 --- /dev/null +++ b/train.py @@ -0,0 +1,81 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import argparse +from pathlib import Path + +import loguru +import torch +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from lightning_train import RASTrainingModule +from utils.config import yaml_to_parser + + +def main(args): + + hparams = yaml_to_parser(args.config) + hparams = hparams.parse_args([]) + exp_name = args.config.stem + + seed_everything(hparams.seed, workers=True) + + # some cuda configs + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = True + + logger = TensorBoardLogger(save_dir="exp", name="eras", version=exp_name) + ckpt_dir = Path(logger.log_dir) / "checkpoints" + + model = RASTrainingModule(hparams, args.data_path) + + if (ckpt_dir / "last.ckpt").exists(): + # resume training from the latest checkpoint + ckpt_path = ckpt_dir / "last.ckpt" + skip_first_validation_loop = True + loguru.logger.info(f"Resume training from {str(ckpt_path)}") + elif getattr(hparams, "pretrained_model_path", None) is not None: + ckpt_path = None + skip_first_validation_loop = True + else: + print("Train from scratch") + ckpt_path = None + skip_first_validation_loop = False + + ckpt_callback = ModelCheckpoint(**hparams.model_checkpoint) + callbacks = [LearningRateMonitor(logging_interval="epoch"), ckpt_callback] + if hparams.early_stopping is not None: + callbacks.append(EarlyStopping(**hparams.early_stopping)) + + trainer = Trainer( + logger=logger, + callbacks=callbacks, + enable_progress_bar=False, + deterministic=True, + devices=-1, + strategy="ddp", + **hparams.trainer_conf, + ) + + # validation 
epoch before training for debugging
+    if skip_first_validation_loop:
+        loguru.logger.info("Skip validating before train when resuming training")
+    else:
+        loguru.logger.info("Validating before train")
+        trainer.validate(model)
+        loguru.logger.info("Finished initial validation")
+
+    # training
+    trainer.fit(model, ckpt_path=ckpt_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=Path, required=True)
+    parser.add_argument("--data_path", type=Path, required=True)
+    args, other_options = parser.parse_known_args()
+    main(args)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/utils/audio_utils.py b/utils/audio_utils.py
new file mode 100755
index 0000000..b932d15
--- /dev/null
+++ b/utils/audio_utils.py
@@ -0,0 +1,231 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import torch
+import torch.nn.functional as F
+from scipy import signal
+
+
+def rotate_channels(input, ref_channel, channel_dim=-1):
+    """
+    Rotate the channel dimension so that the given ref_channel comes first
+    """
+    output = input.transpose(-1, channel_dim)
+    output = torch.cat((output[..., ref_channel:], output[..., :ref_channel]), dim=-1)
+    output = output.transpose(-1, channel_dim)
+    assert input.shape == output.shape, (input.shape, output.shape)
+    return output
+
+
+def stft_padding(
+    input_signal,
+    n_fft,
+    window_length,
+    hop_length,
+    center_pad=True,
+    end_pad=True,
+    pad_mode="constant",
+):
+    signal_length = input_signal.shape[-1]
+    pad_start = 0
+    pad_end = 0
+    if center_pad:
+        # Do center padding here instead of in torch.stft, since we need to do it anyway because they don't
+        # support end padding.
+        pad_start = int(n_fft // 2)
+        pad_end = pad_start
+        signal_length = signal_length + pad_start + pad_end
+    if end_pad:
+        # from scipy.signal.stft
+        # Pad to integer number of windowed segments
+        # I.e. make signal_length = window_length + (nseg-1)*hop_length, with integer nseg
+        nadd = (-(signal_length - window_length) % hop_length) % window_length
+        pad_end += nadd
+
+    # do the padding
+    signal_dim = input_signal.dim()
+    extended_shape = [1] * (3 - signal_dim) + list(input_signal.size())
+    input_signal = F.pad(input_signal.view(extended_shape), (pad_start, pad_end), pad_mode)
+    input_signal = input_signal.view(input_signal.shape[-signal_dim:])
+    return input_signal
+
+
+def _normalize_options(window, method_str):
+    normalize_flag = False
+    if method_str == "window":
+        window = window / torch.sum(window)
+    elif method_str == "default":
+        normalize_flag = True
+    return window, normalize_flag
+
+
+def do_stft(
+    signal,
+    window_length,
+    hop_length,
+    fft_size=None,
+    normalize="default",
+    window_type="sqrt_hann",
+):
+    """
+    Wrap torch.stft, and return a transposed spectrogram for pytorch input
+
+    :param signal: tensor of shape (n_channels, n_samples) or (n_samples)
+    :param window_length: int size of stft window
+    :param hop_length: int stride of stft window
+    :param fft_size: int geq to window_length, if None set to window_length
+    :param normalize: string for determining how to normalize stft output.
+        "window": divide the window by its sum. This will give amplitudes of components that match
+        time domain components, but small values could cause numerical issues
+        "default": default pytorch stft normalization divides by sqrt of window length
+        None: no normalization of stft outputs
+    :param window_type: string indicating the window type ("sqrt_hann", "hann", "blackman", "blackmanharris", "hamming")
+    :return: complex tensor of shape (n_frames, n_frequencies, n_channels) or (n_frames, n_frequencies)
+    """
+    if fft_size is None:
+        fft_size = window_length
+    signal = stft_padding(signal, fft_size, window_length, hop_length)
+    window = get_window(window=window_type, window_length=window_length, device=signal.device)
+    window, normalize_flag = _normalize_options(window, normalize)
+    result = torch.stft(
+        signal,
+        n_fft=fft_size,
+        hop_length=hop_length,
+        win_length=window_length,
+        window=window,
+        normalized=normalize_flag,
+        center=False,
+        return_complex=True,
+    )
+    return result.transpose(0, -1)  # transpose output so time axis is first to better match how torch lstm takes input
+
+
+def do_istft(
+    stft,
+    window_length=None,
+    hop_length=None,
+    fft_size=None,
+    normalize="default",
+    window_type="sqrt_hann",
+):
+    """
+    Wrap torch.istft and return time domain signal
+
+    :param stft: complex tensor of shape (n_frames, n_frequencies, n_channels) or (n_frames, n_frequencies)
+    :param window_length: int size of stft window
+    :param hop_length: int stride of stft window
+    :param fft_size: int geq to window_length, if None set to window_length
+    :param normalize: string for determining how to normalize stft output.
+        "window": divide the window by its sum. This will give amplitudes of components that match
+        time domain components, but small values could cause numerical issues
+        "default": default pytorch stft normalization divides by sqrt of window length
+        None: no normalization of stft outputs
+    :param window_type: string indicating the window type ("sqrt_hann", "hann", "blackman", "blackmanharris", "hamming")
+    :return: tensor of shape (n_samples, n_channels) or (n_samples)
+    """
+
+    window = get_window(window=window_type, window_length=window_length, device=stft.device)
+    if fft_size is None:
+        fft_size = window_length
+    window, normalize_flag = _normalize_options(window, normalize)
+    # Must have center=True to satisfy OLA constraints
+    signal = torch.istft(
+        stft.transpose(0, -1),
+        fft_size,
+        hop_length=hop_length,
+        win_length=window_length,
+        window=window,
+        center=True,
+        normalized=normalize_flag,
+    )
+    return signal.T  # to be consistent with stft, we have sources in the last dimension
+
+
+def get_window(window="sqrt_hann", window_length=1024, device=None):
+    if window == "sqrt_hann":
+        return sqrt_hann(window_length, device=device)
+    elif window == "hann":
+        return torch.hann_window(window_length, periodic=True, device=device)
+    elif window in ["blackman", "blackmanharris", "hamming"]:
+        # use torch.tensor (not the legacy torch.Tensor constructor) so the device kwarg is honored
+        return torch.tensor(signal.get_window(window, window_length), dtype=torch.float32, device=device)
+    else:
+        raise ValueError("Unsupported window type: {}".format(window))
+
+
+def sqrt_hann(window_length, device=None):
+    """Implement a sqrt-Hann window"""
+    return torch.sqrt(torch.hann_window(window_length, periodic=True, device=device))
+
+
+def get_num_fft_bins(fft_size):
+    return fft_size // 2 + 1
+
+
+def get_padded_stft_frames(signal_lens, window_length, hop_length, fft_size):
+    """
+    Compute stft frame lengths taking into account the padding used by our stft code
+
+    :param signal_lens: torch tensor of signal waveform lengths
+    :param window_length: scalar stft window length
+    :param hop_length: scalar stft hop length
+    :param fft_size: scalar stft fft size
+    :return: torch tensor of stft frame lengths
+    """
+    added_padding = 2 * int(fft_size // 2)
+    return torch.ceil((signal_lens + added_padding - window_length) / hop_length + 1)
+
+
+def stft_3dim(input, **kwargs):
+    """Apply the STFT to a 3-dim tensor, one source at a time
+
+    Parameters
+    ----------
+    input: torch.Tensor, (n_batch, n_samples, n_src)
+        Input time-domain waveform
+
+    Returns
+    ----------
+    output: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
+        Output STFT-domain spectrogram
+    """
+    assert input.ndim == 3, input.shape
+    output = [
+        do_stft(
+            input[..., i],
+            kwargs["window_length"],
+            kwargs["hop_length"],
+            kwargs["fft_size"],
+            kwargs["normalize"],
+        )
+        for i in range(input.shape[-1])
+    ]
+    output = torch.stack(output, dim=-1).movedim(2, 0)
+    return output
+
+
+def istft_4dim(input, **kwargs):
+    """Apply the iSTFT to a 4-dim tensor, one batch element at a time
+
+    Parameters
+    ----------
+    input: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
+        Input STFT-domain spectrogram
+
+    Returns
+    ----------
+    output: torch.Tensor, (n_batch, n_samples, n_src)
+        Output time-domain waveform
+    """
+    assert input.ndim == 4, input.shape
+    output = [
+        do_istft(
+            input[i],
+            kwargs["window_length"],
+            kwargs["hop_length"],
+            kwargs["fft_size"],
+            kwargs["normalize"],
+        )
+        for i in range(input.shape[0])
+    ]
+    output = torch.stack(output)
+    return output
diff --git a/utils/collate.py b/utils/collate.py
new file mode 100644
index 0000000..6c0662e
--- /dev/null
+++ b/utils/collate.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import collections.abc as container_abcs
+
+import torch
+import torch.nn.utils.rnn as rnn
+
+
+def collate_seq(batch):
+    elem = batch[0]
+    elem_type = type(elem)
+    if elem_type.__name__ == "ndarray":
+        lengths = torch.tensor([len(b) for b in batch])
+        pad_features = rnn.pad_sequence([torch.tensor(b, dtype=torch.float32) for b in batch], batch_first=True)
+        return pad_features, lengths
+
+    elif isinstance(elem, torch.Tensor):
+        lengths = torch.tensor([len(b) for b in batch])
+        pad_features = rnn.pad_sequence([b for b in batch], batch_first=True)
+        return pad_features, lengths
+
+    elif isinstance(elem, container_abcs.Sequence):
+        transposed = zip(*batch)
+        return [collate_seq(samples) for samples in transposed]
+
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: collate_seq([d[key] for d in batch]) for key in elem}
+
+    else:
+        # for other stuff just return it and do not collate
+        return [b for b in batch]
+
+
+def collate_seq_eras(batch):
+    elem = batch[0]
+    elem_type = type(elem)
+    if elem_type.__name__ == "ndarray":
+        raise RuntimeError("Input must be torch Tensor or Dict")
+
+    elif isinstance(elem, torch.Tensor):
+        lengths = torch.tensor([len(b) for b in batch])
+        pad_features = rnn.pad_sequence([b for b in batch], batch_first=True)
+        pad_features = pad_features.movedim(-1, 1)
+        pad_features = pad_features.reshape((-1,) + pad_features.shape[2:])
+        return pad_features, lengths
+
+    elif isinstance(elem, container_abcs.Sequence):
+        transposed = zip(*batch)
+        return [collate_seq_eras(samples) for samples in transposed]
+
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: collate_seq_eras([d[key] for d in batch]) for key in elem}
+
+    else:
+        # for other stuff just return it and do not collate
+        return [b for b in batch]
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..b45540c
--- /dev/null
+++
b/utils/config.py @@ -0,0 +1,43 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import argparse + +import yaml + + +def _val_to_argparse_kwargs(val): + if isinstance(val, str): + return {"type": strings_with_none, "default": val} + elif isinstance(val, bool): + return {"type": bool_string_to_bool, "default": val} + else: + return {"type": eval, "default": val} + + +def strings_with_none(arg_str): + if arg_str.lower() in ["null", "none"]: + return None + else: + return str(arg_str) + + +def bool_string_to_bool(bool_str): + if str(bool_str).lower() == "false": + return False + elif str(bool_str).lower() == "true": + return True + else: + raise argparse.ArgumentTypeError('For boolean args, use "True" or "False" strings, not {}'.format(bool_str)) + + +def yaml_to_parser(yaml_path): + default_hyperparams = yaml.safe_load(open(yaml_path)) + parser = argparse.ArgumentParser() + + for k, v in default_hyperparams.items(): + argparse_kwargs = _val_to_argparse_kwargs(v) + parser.add_argument("--{}".format(k.replace("_", "-")), **argparse_kwargs) + return parser diff --git a/utils/forward_convolutive_prediction.py b/utils/forward_convolutive_prediction.py new file mode 100644 index 0000000..65ec99a --- /dev/null +++ b/utils/forward_convolutive_prediction.py @@ -0,0 +1,116 @@ +# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import torch + + +def forward_convolutive_prediction( + est: torch.Tensor, + ref: torch.Tensor, + past_taps: int, + future_taps: int, + eps: float = 1e-8, +): + """Forward convolutive prediction (FCP) proposed in [1]. + + Parameters + ---------- + est: torch.Tensor, (n_batch, n_frame, n_freq, n_src) + Signals to which the FCP is applied. + ref: torch.Tensor, (n_batch, n_frame, n_freq, n_chan) + Reference signals of the FCP. + past_taps: int + The number of the past taps in the FCP. + future_taps: int + The number of the future taps in the FCP. + eps: float + Stabilizer for matrix inverse. + + Returns + ---------- + output: torch.Tensor, (n_batch, n_frame, n_freq, n_chan, n_src) + Signals after applying the FCP. + + References + ---------- + [1]: Z.-Q Wang, G. Wichern, and J. Le Roux, + "Convolutive Prediction for Monaural Speech Dereverberation and Noisy-Reverberant Speaker Separation," + IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 29, pp. 3476-3490, 2021. 
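+
+    Examples
+    ----------
+    A minimal sketch (shapes are illustrative, mirroring tests/test_fcp.py):
+
+    >>> est = torch.randn((1, 50, 65, 2), dtype=torch.complex64)  # 2 sources
+    >>> mix = torch.randn((1, 50, 65, 2), dtype=torch.complex64)  # 2 channels
+    >>> forward_convolutive_prediction(est, mix, past_taps=5, future_taps=1).shape
+    torch.Size([1, 50, 65, 2, 2])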
+    """
+
+    n_chan = ref.shape[-1]
+    # stack past and future frames first
+    est_tilde = stack_past_and_future_taps(est, past_taps, future_taps)
+
+    # compute weighting term lambda
+    weight = compute_fcp_weight(ref)
+    weight = weight[..., None, None]
+
+    # compute FCP filter
+    # get auto-covariance matrix
+    auto_cov = torch.einsum("...tafn, ...tbfn -> ...tfnab", est_tilde, est_tilde.conj())
+    auto_cov = (auto_cov / weight).sum(dim=-5)
+    auto_cov = auto_cov.unsqueeze(-4).tile(1, 1, n_chan, 1, 1, 1)
+
+    # get cross-covariance matrix
+    cross_cov = torch.einsum("...tafn, ...tfc -> ...tfcna", est_tilde, ref.conj())
+    cross_cov = (cross_cov / weight).sum(dim=-5)
+
+    # compute relative RIR: (batch, freq, n_chan, n_src, past+1+future)
+    rir = torch.linalg.solve(auto_cov + eps, cross_cov)
+
+    # filter the estimate: (batch, frame, freq, n_chan, n_src)
+    output = torch.einsum("...fcna, ...tafn -> ...tfcn", rir.conj(), est_tilde)
+
+    return output
+
+
+def stack_past_and_future_taps(
+    input,
+    past_tap,
+    future_tap,
+):
+    """Function to stack the past and future frames of the input.
+
+    Parameters
+    ----------
+    input: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
+        Signals to which the FCP is applied.
+    past_tap: int
+        Number of past taps in the FCP.
+    future_tap: int
+        Number of future taps in the FCP.
+
+    Returns
+    ----------
+    output: torch.Tensor, (n_batch, n_frame, past_tap+1+future_tap, n_freq, n_src)
+        A tensor that stacks the past and future frames of the input
+    """
+
+    T = input.shape[-3]
+    indices = torch.arange(0, past_tap + future_tap + 1).view(1, 1, past_tap + future_tap + 1).to(input.device)
+    padded_indices = torch.arange(T).view(1, T, 1).to(input.device) + indices
+    output = torch.nn.functional.pad(input, (0, 0, 0, 0, past_tap, future_tap))
+    output = output[:, padded_indices]
+    return output[:, 0]
+
+
+def compute_fcp_weight(input, eps=1e-4):
+    """Function to calculate the weighting term (lambda) in the FCP [1].
+
+    Parameters
+    ----------
+    input: torch.Tensor, (n_batch, n_frame, n_freq, n_chan)
+        Signal to compute the FCP weight.
+    eps: float
+        Flooring coefficient, relative to the maximum power of the input.
+
+    Returns
+    ----------
+    weight: torch.Tensor, (n_batch, n_frame, n_freq, 1)
+        FCP weighting term, averaged over channels and broadcastable against the input.
+    """
+    power = (abs(input) ** 2).mean(dim=-1)
+    max_power = torch.max(power)
+    weight = power + eps * max_power
+    return weight.unsqueeze(-1)
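+
+
+if __name__ == "__main__":
+    # Minimal smoke test of the helpers above (illustrative shapes only,
+    # mirroring tests/test_fcp.py; not part of the library API).
+    est = torch.randn((1, 50, 65, 2), dtype=torch.complex64)
+    stacked = stack_past_and_future_taps(est, past_tap=5, future_tap=1)
+    print(stacked.shape)  # torch.Size([1, 50, 7, 65, 2]): 5 past + current + 1 future tap
+    weight = compute_fcp_weight(torch.randn((1, 50, 65, 2), dtype=torch.complex64))
+    print(weight.shape)  # torch.Size([1, 50, 65, 1]): broadcastable weighting term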