diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml
new file mode 100644
index 0000000..fb92e1c
--- /dev/null
+++ b/.github/workflows/build_and_test.yaml
@@ -0,0 +1,60 @@
+# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+name: Build and Test
+
+on:
+ pull_request:
+ push:
+ branches:
+ - '**'
+ tags-ignore:
+ - '**'
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ shell: bash -l {0}
+
+ steps:
+ - name: Checkout repo
+ uses: actions/checkout@v4
+
+ - name: Create environment with miniconda
+ uses: conda-incubator/setup-miniconda@v3
+ with:
+ miniforge-variant: Mambaforge
+ miniforge-version: latest
+ use-mamba: true
+ activate-environment: ras
+
+ - name: Get Date
+ id: get-date
+ run: echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT
+ shell: bash
+
+ - name: Cache Conda env
+ uses: actions/cache@v3
+ with:
+ path: ${{ env.CONDA }}/envs
+ key:
+ conda-${{ runner.os }}--${{ runner.arch }}--${{
+ steps.get-date.outputs.today }}-${{
+ hashFiles('environment.yaml') }}-${{ env.CACHE_NUMBER
+ }}
+ env:
+        # Increase this value to reset cache if environment.yaml has not changed
+ CACHE_NUMBER: 0
+ id: cache
+
+ - name: Update environment
+ run: mamba env update -n ras -f environment.yaml
+ if: steps.cache.outputs.cache-hit != 'true'
+
+ - name: Run unit tests
+ run: |
+ pip install pytest
+ pytest
diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml
new file mode 100644
index 0000000..2db7f87
--- /dev/null
+++ b/.github/workflows/static_checks.yaml
@@ -0,0 +1,77 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+name: Static code checks
+
+on:
+ pull_request:
+ push:
+ branches:
+ - '**'
+ tags-ignore:
+ - '**'
+
+env:
+ LICENSE: AGPL-3.0-or-later
+ FETCH_DEPTH: 1
+ FULL_HISTORY: 0
+ SKIP_WORD_PRESENCE_CHECK: 0
+
+jobs:
+ static-code-check:
+ if: endsWith(github.event.repository.name, 'private')
+
+ name: Run static code checks
+ # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu1804-Readme.md for list of packages
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ shell: bash -l {0}
+
+ steps:
+ - name: Setup history
+ if: github.ref == 'refs/heads/oss'
+ run: |
+ echo "FETCH_DEPTH=0" >> $GITHUB_ENV
+ echo "FULL_HISTORY=1" >> $GITHUB_ENV
+
+ - name: Setup version
+ if: github.ref == 'refs/heads/melco'
+ run: |
+ echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV
+
+ - name: Check out code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history
+
+ - name: Set up environment
+ run: git config user.email github-bot@merl.com
+
+ - name: Set up python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.11'
+ cache: 'pip'
+ cache-dependency-path: 'requirements-dev.txt'
+
+ - name: Install python packages
+ run: pip install -r requirements-dev.txt
+
+ - name: Ensure lint and pre-commit steps have been run
+ uses: pre-commit/action@v3.0.1
+
+ - name: Check files
+ uses: merl-oss-private/merl-file-check-action@v1
+ with:
+ license: ${{ env.LICENSE }}
+ full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above
+ skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }}
+
+ - name: Check license compatibility
+ if: github.ref != 'refs/heads/melco'
+ uses: merl-oss-private/merl_license_compatibility_checker@v1
+ with:
+ input-filename: environment.yaml
+ license: ${{ env.LICENSE }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d2a45c9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,177 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# vscode
+.vscode/
+
+# Experiment
+lightning_logs/
+log/
+exp*/
+evaluation/
+*.ckpt
+*.png
+*.wav
+*.pdf
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..b03c347
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,63 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+# Pre-commit configuration. See https://pre-commit.com
+
+default_language_version:
+ python: python3.11
+
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.6.0
+ hooks:
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+ - id: check-yaml
+ - id: check-added-large-files
+ args: ['--maxkb=1000']
+
+ - repo: https://gitlab.com/bmares/check-json5
+ rev: v1.0.0
+ hooks:
+ - id: check-json5
+
+ - repo: https://github.com/homebysix/pre-commit-macadmin
+ rev: v1.16.2
+ hooks:
+ - id: check-git-config-email
+ args: ['--domains', 'merl.com']
+
+ - repo: https://github.com/psf/black
+ rev: 24.4.2
+ hooks:
+ - id: black
+ args:
+ - --line-length=120
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"]
+
+ # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python
+ # - repo: https://github.com/asottile/pyupgrade
+ # rev: v3.3.1
+ # hooks:
+ # - id: pyupgrade
+
+ # To stop flake8 error from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings,
+ # so use verbose: true to see them.
+ - repo: https://github.com/pycqa/flake8
+ rev: 7.1.0
+ hooks:
+ - id: flake8
+ # Black compatibility, Eradicate options
+ args: ["--max-line-length=119", "--extend-ignore=E203", "--eradicate-whitelist-extend", "eradicate:\\s*no",
+ "--exit-zero"]
+ verbose: true
+ additional_dependencies: [
+ # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate
+ "flake8-eradicate"
+ ]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..bf6e50a
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,8 @@
+
+# Contributing
+
+Sorry, but we do not currently accept contributions in the form of pull requests to this repository. However, you are welcome to post issues (bug reports, feature requests, questions, etc.).
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..cba6f6a
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,660 @@
+### GNU AFFERO GENERAL PUBLIC LICENSE
+
+Version 3, 19 November 2007
+
+Copyright (C) 2007 Free Software Foundation, Inc.
+
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+
+### Preamble
+
+The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains
+free software for all its users.
+
+When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing
+under this license.
+
+The precise terms and conditions for copying, distribution and
+modification follow.
+
+### TERMS AND CONDITIONS
+
+#### 0. Definitions.
+
+"This License" refers to version 3 of the GNU Affero General Public
+License.
+
+"Copyright" also means copyright-like laws that apply to other kinds
+of works, such as semiconductor masks.
+
+"The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of
+an exact copy. The resulting work is called a "modified version" of
+the earlier work or a work "based on" the earlier work.
+
+A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user
+through a computer network, with no transfer of a copy, is not
+conveying.
+
+An interactive user interface displays "Appropriate Legal Notices" to
+the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+#### 1. Source Code.
+
+The "source code" for a work means the preferred form of the work for
+making modifications to it. "Object code" means any non-source form of
+a work.
+
+A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+The Corresponding Source need not include anything that users can
+regenerate automatically from other parts of the Corresponding Source.
+
+The Corresponding Source for a work in source code form is that same
+work.
+
+#### 2. Basic Permissions.
+
+All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+You may make, run and propagate covered works that you do not convey,
+without conditions so long as your license otherwise remains in force.
+You may convey covered works to others for the sole purpose of having
+them make modifications exclusively for you, or provide you with
+facilities for running those works, provided that you comply with the
+terms of this License in conveying all material for which you do not
+control copyright. Those thus making or running the covered works for
+you must do so exclusively on your behalf, under your direction and
+control, on terms that prohibit them from making any copies of your
+copyrighted material outside their relationship with you.
+
+Conveying under any other circumstances is permitted solely under the
+conditions stated below. Sublicensing is not allowed; section 10 makes
+it unnecessary.
+
+#### 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such
+circumvention is effected by exercising rights under this License with
+respect to the covered work, and you disclaim any intention to limit
+operation or modification of the work as a means of enforcing, against
+the work's users, your or third parties' legal rights to forbid
+circumvention of technological measures.
+
+#### 4. Conveying Verbatim Copies.
+
+You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+#### 5. Conveying Modified Source Versions.
+
+You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these
+conditions:
+
+- a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+- b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under
+ section 7. This requirement modifies the requirement in section 4
+ to "keep intact all notices".
+- c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+- d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+#### 6. Conveying Non-Source Forms.
+
+You may convey a covered work in object code form under the terms of
+sections 4 and 5, provided that you also convey the machine-readable
+Corresponding Source under the terms of this License, in one of these
+ways:
+
+- a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+- b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the Corresponding
+ Source from a network server at no charge.
+- c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+- d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+- e) Convey the object code using peer-to-peer transmission,
+ provided you inform other peers where the object code and
+ Corresponding Source of the work are being offered to the general
+ public at no charge under subsection 6d.
+
+A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal,
+family, or household purposes, or (2) anything designed or sold for
+incorporation into a dwelling. In determining whether a product is a
+consumer product, doubtful cases shall be resolved in favor of
+coverage. For a particular product received by a particular user,
+"normally used" refers to a typical or common use of that class of
+product, regardless of the status of the particular user or of the way
+in which the particular user actually uses, or expects or is expected
+to use, the product. A product is a consumer product regardless of
+whether the product has substantial commercial, industrial or
+non-consumer uses, unless such uses represent the only significant
+mode of use of the product.
+
+"Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to
+install and execute modified versions of a covered work in that User
+Product from a modified version of its Corresponding Source. The
+information must suffice to ensure that the continued functioning of
+the modified object code is in no case prevented or interfered with
+solely because modification has been made.
+
+If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or
+updates for a work that has been modified or installed by the
+recipient, or for the User Product in which it has been modified or
+installed. Access to a network may be denied when the modification
+itself materially and adversely affects the operation of the network
+or violates the rules and protocols for communication across the
+network.
+
+Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+#### 7. Additional Terms.
+
+"Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders
+of that material) supplement the terms of this License with terms:
+
+- a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+- b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+- c) Prohibiting misrepresentation of the origin of that material,
+ or requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+- d) Limiting the use for publicity purposes of names of licensors
+ or authors of the material; or
+- e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+- f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions
+ of it) with contractual assumptions of liability to the recipient,
+ for any liability that these contractual assumptions directly
+ impose on those licensors and authors.
+
+All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions; the
+above requirements apply either way.
+
+#### 8. Termination.
+
+You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+However, if you cease all violation of this License, then your license
+from a particular copyright holder is reinstated (a) provisionally,
+unless and until the copyright holder explicitly and finally
+terminates your license, and (b) permanently, if the copyright holder
+fails to notify you of the violation by some reasonable means prior to
+60 days after the cessation.
+
+Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+#### 9. Acceptance Not Required for Having Copies.
+
+You are not required to accept this License in order to receive or run
+a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+#### 10. Automatic Licensing of Downstream Recipients.
+
+Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+#### 11. Patents.
+
+A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+A contributor's "essential patent claims" are all patent claims owned
+or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+A patent license is "discriminatory" if it does not include within the
+scope of its coverage, prohibits the exercise of, or is conditioned on
+the non-exercise of one or more of the rights that are specifically
+granted under this License. You may not convey a covered work if you
+are a party to an arrangement with a third party that is in the
+business of distributing software, under which you make payment to the
+third party based on the extent of your activity of conveying the
+work, and under which the third party grants, to any of the parties
+who would receive the covered work from you, a discriminatory patent
+license (a) in connection with copies of the covered work conveyed by
+you (or copies made from those copies), or (b) primarily for and in
+connection with specific products or compilations that contain the
+covered work, unless you entered into that arrangement, or that patent
+license was granted, prior to 28 March 2007.
+
+Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+#### 12. No Surrender of Others' Freedom.
+
+If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under
+this License and any other pertinent obligations, then as a
+consequence you may not convey it at all. For example, if you agree to
+terms that obligate you to collect a royalty for further conveying
+from those to whom you convey the Program, the only way you could
+satisfy both those terms and this License would be to refrain entirely
+from conveying the Program.
+
+#### 13. Remote Network Interaction; Use with the GNU General Public License.
+
+Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your
+version supports such interaction) an opportunity to receive the
+Corresponding Source of your version by providing access to the
+Corresponding Source from a network server at no charge, through some
+standard or customary means of facilitating copying of software. This
+Corresponding Source shall include the Corresponding Source for any
+work covered by version 3 of the GNU General Public License that is
+incorporated pursuant to the following paragraph.
+
+Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+#### 14. Revised Versions of this License.
+
+The Free Software Foundation may publish revised and/or new versions
+of the GNU Affero General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+Each version is given a distinguishing version number. If the Program
+specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever
+published by the Free Software Foundation.
+
+If the Program specifies that a proxy can decide which future versions
+of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+#### 15. Disclaimer of Warranty.
+
+THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT
+WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND
+PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE
+DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR
+CORRECTION.
+
+#### 16. Limitation of Liability.
+
+IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR
+CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
+INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES
+ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT
+NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR
+LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM
+TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER
+PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+#### 17. Interpretation of Sections 15 and 16.
+
+If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+END OF TERMS AND CONDITIONS
+
+### How to Apply These Terms to Your New Programs
+
+If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these
+terms.
+
+To do so, attach the following notices to the program. It is safest to
+attach them to the start of each source file to most effectively state
+the exclusion of warranty; and each file should have at least the
+"copyright" line and a pointer to where the full notice is found.
+
+
+    Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as
+ published by the Free Software Foundation, either version 3 of the
+ License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper
+mail.
+
+If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for
+the specific requirements.
+
+You should also get your employer (if you work as a programmer) or
+school, if any, to sign a "copyright disclaimer" for the program, if
+necessary. For more information on this, and how to apply and follow
+the GNU AGPL, see <https://www.gnu.org/licenses/>.
diff --git a/LICENSES/Apache-2.0.md b/LICENSES/Apache-2.0.md
new file mode 100644
index 0000000..6afea95
--- /dev/null
+++ b/LICENSES/Apache-2.0.md
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bc2dbff
--- /dev/null
+++ b/README.md
@@ -0,0 +1,104 @@
+
+# Enhanced Reverberation as Supervision for Unsupervised Speech Separation
+
+This repository includes source code for training and evaluating the enhanced reverberation as supervision (ERAS), proposed in the following Interspeech 2024 paper:
+
+```
+@InProceedings{Saijo2024_eras,
+ author = {Saijo, Kohei and Wichern, Gordon and Germain, Fran\c{c}ois G. and Pan, Zexu and {Le Roux}, Jonathan},
+ title = {Enhanced Reverberation as Supervision for Unsupervised Speech Separation},
+ booktitle = {Proc. Annual Conference of International Speech Communication Association (INTERSPEECH)},
+ year = 2024,
+ month = sep
+}
+```
+
+## Table of contents
+
+1. [Installation](#installation)
+2. [How to run](#how-to-run)
+3. [Contributing](#contributing)
+4. [Copyright and license](#copyright-and-license)
+
+## Installation
+
+Clone this repo and create the anaconda environment
+
+```sh
+git clone https://github.com/merlresearch/reverberation-as-supervision
+cd reverberation-as-supervision && conda env create -f environment.yaml
+```
+
+## How to run
+
+This repository supports training on two datasets used in the paper, **WHAMR!** and **SMS-WSJ**.
+Example training configuration files are under `./configs/*dataset-name*`.
+
+Before starting training, run the following command:
+
+```sh
+conda activate ras
+```
+
+The main script for training is in `train.py`, which can be run by
+
+```sh
+python train.py --config /path/to/config --data_path /path/to/data
+```
+
+Here, `/path/to/data` is the directory containing `wav8k` and `wav16k` directories for WHAMR! and that containing `sms_wsj.json` for SMS-WSJ.
+
+As demonstrated in the paper, a best-performing model is obtained by two-stage training.
+One can first pre-train a model and then fine-tune it as follows (example commands on WHAMR!).
+
+```sh
+# Train a model with ISMS-loss weight of 0.3 for 20 epochs.
+python train.py --config ./configs/whamr/eras_whamr_isms0.3_icc0.0.yaml --data_path /path/to/whamr
+
+# Fine-tune the pre-trained model without the ISMS loss and with the ICC loss for 80 epochs.
+# Note that the pre-trained model's path has to be specified in the yaml file.
+python train.py --config ./configs/whamr/eras_whamr_isms0.0_icc0.1.yaml --data_path /path/to/whamr
+```
+
+The checkpoints and tensorboard logs are saved under `exp/eras/*config-name*` directory.
+After finishing the training, separation performance can be evaluated using `eval.py`:
+
+```sh
+python eval.py --ckpt_path /path/to/.ckpt-file --data_path /path/to/data
+```
+
+The evaluation scores are logged in the tensorboard.
+
+## Contributing
+
+See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions.
+
+## Copyright and license
+
+Released under `AGPL-3.0-or-later` license, as found in the [LICENSE.md](LICENSE.md) file.
+
+All files, except as noted below:
+
+```
+Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+
+SPDX-License-Identifier: AGPL-3.0-or-later
+```
+
+The following file:
+
+- `nets/tfgridnetv2.py`
+
+was adapted from https://github.com/espnet/espnet (license included in [LICENSES/Apache-2.0.md](LICENSES/Apache-2.0.md))
+
+```
+Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+Copyright (C) 2023 ESPnet Developers
+
+SPDX-License-Identifier: AGPL-3.0-or-later
+SPDX-License-Identifier: Apache-2.0
+```
diff --git a/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml b/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml
new file mode 100644
index 0000000..21e5872
--- /dev/null
+++ b/configs/smswsj/eras_smswsj_isms0.0_icc0.1.yaml
@@ -0,0 +1,94 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path).
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 80
+ limit_train_batches: 0.5
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1]
+ running_eras: true
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.0
+ isms_loss_weight: 0.0
+ icc_loss_weight: 0.1
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 3
+ factor: 0.5
+
+# fine-tuning config
+pretrained_model_path: ./exp/eras/eras_smswsj_isms0.3_icc0.0/checkpoints/last.ckpt
+warmup_steps: 4000
+
+# General parameters
+output_audio: true # save wav files of separated sources during testing in "audio_output" directory
+log_file: ""
diff --git a/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml b/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml
new file mode 100644
index 0000000..2095b37
--- /dev/null
+++ b/configs/smswsj/eras_smswsj_isms0.3_icc0.0.yaml
@@ -0,0 +1,92 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path).
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 20
+ limit_train_batches: 0.5
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1]
+ running_eras: true
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.0
+ isms_loss_weight: 0.3
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 3
+ factor: 0.5
+
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml b/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml
new file mode 100644
index 0000000..53206a8
--- /dev/null
+++ b/configs/smswsj/unssor_smswsj_input1chloss3ch_isms0.06_refmic0.1666.yaml
@@ -0,0 +1,92 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 4 # mini-batch size
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 100
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 2, 4]
+ running_eras: false
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.166666 # 1/6
+ isms_loss_weight: 0.06
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 2
+ factor: 0.5
+
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml b/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml
new file mode 100644
index 0000000..3218faf
--- /dev/null
+++ b/configs/smswsj/unssor_smswsj_input1chloss6ch_isms0.02_refmic0.2.yaml
@@ -0,0 +1,92 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 4 # mini-batch size
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 100
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1, 2, 3, 4, 5]
+ running_eras: false
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.2
+ isms_loss_weight: 0.02
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 2
+ factor: 0.5
+
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml b/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml
new file mode 100644
index 0000000..c64524e
--- /dev/null
+++ b/configs/smswsj/unssor_smswsj_input3chloss3ch_isms0.06_refmic0.1666.yaml
@@ -0,0 +1,92 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 4 # mini-batch size
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 100
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 2, 4]
+ running_eras: false
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.166666 # 1/6
+ isms_loss_weight: 0.06
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 3
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 2
+ factor: 0.5
+
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml b/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml
new file mode 100644
index 0000000..5893655
--- /dev/null
+++ b/configs/smswsj/unssor_smswsj_input6chloss6ch_isms0.02_refmic0.2.yaml
@@ -0,0 +1,92 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 4 # mini-batch size
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 100
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: smswsj
+dataset_conf:
+ placeholder: null # just a placeholder
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1, 2, 3, 4, 5]
+ running_eras: false
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.2
+ isms_loss_weight: 0.02
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 6
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 2
+ factor: 0.5
+
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml b/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml
new file mode 100644
index 0000000..c0ea0ec
--- /dev/null
+++ b/configs/whamr/eras_whamr_isms0.0_icc0.1.yaml
@@ -0,0 +1,96 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path).
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 80
+ limit_train_batches: 0.5
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: whamr
+dataset_conf:
+ task: sep_reverb_all_srcs
+ use_min: true
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1]
+ running_eras: true
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.0
+ isms_loss_weight: 0.0
+ icc_loss_weight: 0.1
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 3
+ factor: 0.5
+
+# fine-tuning config
+pretrained_model_path: ./exp/eras/eras_whamr_isms0.3_icc0.0/checkpoints/last.ckpt
+warmup_steps: 4000
+
+
+# General parameters
+output_audio: true # save wav files of separated sources during testing in "audio_output" directory
+log_file: ""
diff --git a/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml b/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml
new file mode 100644
index 0000000..36f3b44
--- /dev/null
+++ b/configs/whamr/eras_whamr_isms0.3_icc0.0.yaml
@@ -0,0 +1,94 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+# Basic training config
+batch_size: 2 # batch size is doubled in dataloader in ERAS training (becomes 4 in forward path).
+val_batch_size: 4 # cv batch size
+seed: 1128 # seed for initializing training
+shuffle: true # shuffle training dataset during training
+num_workers: 4 # number of workers in dataloaders, 0 for single thread
+
+# trainer args
+trainer_conf:
+ max_epochs: 20
+ limit_train_batches: 0.5
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
+ gradient_clip_val: 1.0
+
+# early stopping configurations
+early_stopping: null
+
+# checkpoint saving
+model_checkpoint:
+ monitor: val/loss
+ save_top_k: 1
+ mode: min
+ save_weights_only: false
+ save_last: true
+
+# train/dev dataset
+dataset_name: whamr
+dataset_conf:
+ task: sep_reverb_all_srcs
+ use_min: true
+
+# torch dataset
+dataloading_conf:
+ sr: 8000
+ chunk_size: 4
+ chunking_strategy: random
+ ref_channel: &ref_channel 0
+ channel_idx: [0, 1]
+ running_eras: true
+ normalization: true
+
+stft_conf: &stft_conf
+ fft_size: &fft_size 256
+ window_length: *fft_size
+ hop_length: 64
+ window_type: sqrt_hann
+ normalize: window
+
+eras_loss_conf:
+ loss_func: complex_l1
+ past_taps: 19
+ future_taps: 1
+ ref_channel_loss_weight: 0.0
+ isms_loss_weight: 0.3
+ icc_loss_weight: 0.0
+ ref_channel: *ref_channel
+ unsupervised: true
+ supervised_loss_type: after_filtering_ref_channel # used for validation
+ stft_conf: *stft_conf
+
+# Network parameters
+model_name: tfgridnetv2
+model_conf:
+ fft_size: *fft_size
+ n_srcs: 2
+ n_imics: 1
+ n_layers: 4
+ lstm_hidden_units: 256
+ attn_n_head: 4
+ attn_approx_qk_dim: 512
+ emb_dim: 48
+ emb_ks: 4
+ emb_hs: 1
+ eps: 1.0e-5
+
+# Adam optimizer and ReduceLROnPlateau scheduler
+optimizer_conf:
+ lr: 1.0e-3
+scheduler_conf:
+ patience: 3
+ factor: 0.5
+
+# fine-tuning config
+pretrained_model_path: null
+
+
+# General parameters
+log_file: ""
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/datasets/dataset_creator.py b/datasets/dataset_creator.py
new file mode 100644
index 0000000..75a7a3c
--- /dev/null
+++ b/datasets/dataset_creator.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+from .paths.smswsj import get_smswsj_paths
+from .paths.whamr import get_whamr_paths
+from .stft_dataset import STFTDataset
+
+
+def dataset_creator(hparams, data_path, partition):
+ path_list = datapath_creator(hparams.dataset_name, data_path, partition, hparams.dataset_conf)
+
+ # some setups specific for training stage
+ is_training = partition == "tr"
+
+ dataset = STFTDataset(
+ path_list,
+ is_training,
+ hparams.stft_conf,
+ **hparams.dataloading_conf,
+ )
+ return dataset
+
+
+def datapath_creator(dataset_name, data_path, partition, dataset_conf):
+ dataset_conf["partition"] = partition
+ if dataset_name == "whamr":
+ path_list = get_whamr_paths(data_path, **dataset_conf)
+ elif dataset_name == "smswsj":
+ path_list = get_smswsj_paths(data_path, **dataset_conf)
+ else:
+ raise ValueError("Dataset {} not currently supported.".format(dataset_name))
+ return path_list
diff --git a/datasets/paths/__init__.py b/datasets/paths/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/datasets/paths/smswsj.py b/datasets/paths/smswsj.py
new file mode 100644
index 0000000..f29eaf4
--- /dev/null
+++ b/datasets/paths/smswsj.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import json
+
+PARTITION_MAP = {
+ "tr": {"name": "train_si284", "num_data": 33561},
+ "cv": {"name": "cv_dev93", "num_data": 982},
+ "tt": {"name": "test_eval92", "num_data": 1332},
+}
+
+
+def get_smswsj_paths(
+ data_path,
+ partition="tt",
+ **kwargs,
+):
+ # rename partition
+ stage = PARTITION_MAP[partition]["name"]
+
+ # load metadata of smswsj
+ with open(data_path / "sms_wsj.json") as f:
+ mixinfo = json.load(f)
+ mixinfo = mixinfo["datasets"][stage]
+
+ pathlist = []
+ for key, info in mixinfo.items():
+ # information of a sample
+ tmp = {
+ "id": key,
+ "mix": info["audio_path"]["observation"],
+ "srcs": {
+ "reverb1": info["audio_path"]["speech_image"][0],
+ "reverb2": info["audio_path"]["speech_image"][1],
+ "anechoic1": info["audio_path"]["speech_reverberation_early"][0],
+ "anechoic2": info["audio_path"]["speech_reverberation_early"][1],
+ "dry1": info["audio_path"]["speech_source"][0],
+ "dry2": info["audio_path"]["speech_source"][1],
+ },
+ "num_samples": info["num_samples"]["speech_source"],
+ "offset": info["offset"],
+ }
+ # add to pathlist
+ pathlist.append(tmp)
+ # check number of data
+ assert len(pathlist) == PARTITION_MAP[partition]["num_data"]
+
+ if partition == "cv":
+ # sort by length for efficient validation
+        # reducing zero-padding leads to more accurate validation
+ pathlist = sorted(pathlist, key=lambda x: x["num_samples"])
+
+ return pathlist
diff --git a/datasets/paths/whamr.py b/datasets/paths/whamr.py
new file mode 100644
index 0000000..55fe2d0
--- /dev/null
+++ b/datasets/paths/whamr.py
@@ -0,0 +1,122 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import os
+from collections import OrderedDict
+
+
+def get_whamr_paths(
+ data_path,
+ task="sep_noisy",
+ partition="tt",
+ sr=8000,
+ use_min=True,
+):
+ POSSIBLE_TASKS = [
+ "sep_clean_reverb",
+ "denoise_both_reverb",
+ "sep_reverb_all_srcs",
+ "sep_noisy_reverb_all_srcs",
+ ]
+
+ assert task in POSSIBLE_TASKS, task
+ S1_ANECHOIC_DIR = "s1_anechoic"
+ S2_ANECHOIC_DIR = "s2_anechoic"
+ S1_REVERB_DIR = "s1_reverb"
+ S2_REVERB_DIR = "s2_reverb"
+ S1_DRY = "s1"
+ S2_DRY = "s2"
+ BOTH_REVERB_DIR = "mix_both_reverb"
+ CLEAN_REVERB_DIR = "mix_clean_reverb"
+
+ if task == "sep_reverb_all_srcs":
+ mix_dir = CLEAN_REVERB_DIR
+ src_dir_list = [
+ S1_ANECHOIC_DIR,
+ S2_ANECHOIC_DIR,
+ S1_REVERB_DIR,
+ S2_REVERB_DIR,
+ S1_DRY,
+ S2_DRY,
+ ]
+ elif task == "sep_noisy_reverb_all_srcs":
+ mix_dir = BOTH_REVERB_DIR
+ src_dir_list = [
+ S1_ANECHOIC_DIR,
+ S2_ANECHOIC_DIR,
+ S1_REVERB_DIR,
+ S2_REVERB_DIR,
+ S1_DRY,
+ S2_DRY,
+ ]
+
+ else:
+ raise ValueError("WHAMR task {} not available, please choose from {}".format(task, POSSIBLE_TASKS))
+ return get_wsj2mix_paths(
+ data_path,
+ partition=partition,
+ sr=sr,
+ use_min=use_min,
+ mix_dir=mix_dir,
+ src_dir_list=src_dir_list,
+ )
+
+
+def get_wsj2mix_paths(
+ data_path,
+ partition="tt",
+ sr=8000,
+ use_min=True,
+ mix_dir=None,
+ src_dir_list=None,
+):
+ if mix_dir is None:
+ mix_dir = "mix"
+ if src_dir_list is None:
+ src_dir_list = ["s1", "s2"]
+ if sr == 8000:
+ wav_dir = "wav8k"
+ elif sr == 16000:
+ wav_dir = "wav16k"
+ else:
+ raise ValueError("set wsj0-2mix dataset sample rate to either 8kHz or 16kHz")
+
+ if use_min:
+ max_or_min_dir = "min"
+ else:
+ max_or_min_dir = "max"
+
+ root_path = os.path.join(data_path, wav_dir, max_or_min_dir, partition)
+ filelist = [f for f in os.listdir(os.path.join(root_path, src_dir_list[0])) if f.endswith(".wav")]
+ if partition == "tr":
+ assert len(filelist) == 20000, "Expected 20000 files in training set"
+ elif partition == "cv":
+ assert len(filelist) == 5000, "Expected 5000 files in validation set"
+ elif partition == "tt":
+ assert len(filelist) == 3000, "Expected 3000 files in testing set"
+ path_list = get_path_list(filelist, root_path, mix_dir, src_dir_list)
+
+ return path_list
+
+
+def get_path_dict(filename, root_dir, mix_dir, src_dir_list):
+ id_ = os.path.splitext(filename)[0]
+ path_dict = {
+ "id": id_,
+ "srcs": OrderedDict({s: os.path.join(root_dir, s, filename) for s in src_dir_list}),
+ "mix": os.path.join(root_dir, mix_dir, filename),
+ }
+ return path_dict
+
+
+def get_path_list(filelist, root_dir, mix_dir, src_dir_list):
+ # if type(filelist) == type([]):
+ if isinstance(filelist, list):
+ path_list = [get_path_dict(f, root_dir, mix_dir, src_dir_list) for f in filelist]
+ if isinstance(filelist, dict):
+ path_list = {}
+ for key in filelist.keys():
+ path_list[key] = [get_path_dict(f, root_dir, mix_dir, src_dir_list) for f in filelist[key]]
+ return path_list
diff --git a/datasets/stft_dataset.py b/datasets/stft_dataset.py
new file mode 100755
index 0000000..98edfed
--- /dev/null
+++ b/datasets/stft_dataset.py
@@ -0,0 +1,279 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+from typing import Dict, List, Union
+
+import numpy as np
+import soundfile as sf
+import torch
+
+from utils.audio_utils import do_stft, rotate_channels
+
+
+class STFTDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ path_list: List,
+ training: bool,
+ stft_conf: Dict,
+ sr: int = 8000,
+ chunk_size: int = None,
+ chunking_strategy: str = "random",
+ ref_channel: int = 0,
+ channel_idx: Union[int, List[int]] = None,
+ running_eras: bool = False,
+ normalization: bool = False,
+ ):
+ """Dataset class.
+
+ Parameters
+ ----------
+ path_list: List[Dict]
+ List of the metadata of each audio data.
+ training: bool
+ Whether or not this dataset is initialized for training or not.
+ stft_conf: Dict
+ STFT configuration.
+ sr: int
+ Sampling rate.
+ chunk_size: int or float
+ Input length in seconds during training.
+ chunking_strategy: str
+ How to make a training chunk.
+ Choices are random, partial_overlap, or full_overlap.
+ NOTE: this is considered only for SMS-WSJ.
+ ref_channel: int
+ Reference channel index.
+ channel_idx: int or List[int]
+ Channel index (indices) of the input data.
+ To return multi-channel sources, specify the list.
+ running_eras: bool
+ Whether the training algorithm is ERAS or not.
+ ERAS needs a batch which includes both left and right channels in the same mini-batch
+ for the ICC loss (i.e., batch: [data1_L, data1_R, ..., dataN_L, dataN_R]).
+ normalization: bool
+ Whether to apply the variance normalization.
+ """
+
+ self.training = training
+ self.running_eras = running_eras and self.training
+ if self.running_eras:
+ print(
+ f"Use {channel_idx}-th channel(s) as reference microphone(s)",
+ flush=True,
+ )
+
+ self.path_list = path_list
+ self.sr = sr
+
+ self.chunk_size = int(chunk_size * sr) if self.training else None
+
+ self.ref_channel = ref_channel
+
+ self.stft_conf = stft_conf
+ self.channel_idx = channel_idx
+ self.normalization = normalization
+
+ assert chunking_strategy in ["random", "partial_overlap", "full_overlap"]
+ self.chunking_stragegy = chunking_strategy
+
+ def __len__(self):
+ return len(self.path_list)
+
+ def __getitem__(self, index):
+ y_mix, y_srcs = self._read_audio(index)
+ y_mix = torch.from_numpy(y_mix)
+ if isinstance(y_srcs, dict):
+ y_srcs = {tag: torch.from_numpy(y_srcs[tag]) for tag in y_srcs.keys()}
+ elif y_srcs is not None:
+ y_srcs = torch.from_numpy(y_srcs)
+ mix_stft = self._stft(y_mix.T) # (frame, freq, n_chan)
+
+ srcs = dict(y_mix=y_mix, y_srcs=y_srcs)
+ if self.running_eras:
+ mix_stft, srcs = self._rotate_and_stack_channels_stft(mix_stft, srcs)
+ srcs["y_mix_stft"] = mix_stft
+ return mix_stft, srcs
+
+ def _read_audio(self, index, start_samp=0, end_samp=None):
+ path_dict = self.path_list[index % len(self.path_list)]
+ mix_path = path_dict["mix"]
+ n_frames = sf.info(mix_path).frames
+
+ # chunking, for SMS-WSJ
+ if self.chunk_size is None or n_frames <= self.chunk_size:
+ if self.chunking_stragegy == "full_overlap":
+ offset = max(path_dict["offset"]) # start of shorter utterance
+ min_length = min(path_dict["num_samples"]) # length of shorter utterance
+ start_samp, end_samp = offset, min_length + offset
+ else:
+ start_samp = 0
+ end_samp = None
+        else:  # n_frames > self.chunk_size
+ # in some partially-overlapped data we need to find fully-overlapped segment
+ # now we assume number of sources is two
+ if self.chunking_stragegy == "full_overlap":
+ assert "num_samples" in path_dict and "offset" in path_dict
+ offset = max(path_dict["offset"]) # start of shorter utterance
+ min_length = min(path_dict["num_samples"]) # length of shorter utterance
+
+ if min_length <= self.chunk_size:
+ start_samp, end_samp = offset, min_length + offset
+ else:
+ start_samp = np.random.randint(offset, min_length + offset - self.chunk_size)
+ end_samp = start_samp + self.chunk_size
+
+ elif self.chunking_stragegy == "partial_overlap":
+ start_samp, end_samp = self._random_start_and_end(path_dict, min_overlap=self.chunk_size // 4)
+
+ elif self.chunking_stragegy == "random":
+ start_samp = np.random.randint(0, n_frames - self.chunk_size)
+ end_samp = start_samp + self.chunk_size
+
+ else:
+ raise NotImplementedError(self.chunking_stragegy)
+
+ y_mix = self._read_wav(mix_path, start_samp, end_samp)
+ y_srcs = {"reverb": []}
+ for tag, src_path in path_dict["srcs"].items():
+ if "reverb" in tag:
+ y_srcs["reverb"].append(self._read_wav(src_path, start_samp, end_samp))
+ assert len(y_srcs["reverb"]) > 0, ("reverb", path_dict.keys())
+ y_srcs["reverb"] = np.stack(y_srcs["reverb"], axis=-1)
+
+ # normalization:
+ if self.normalization:
+ y_mix, y_srcs = self._normalization(y_mix, y_srcs)
+
+ return y_mix, y_srcs
+
+ def _read_wav(self, path, start_samp=0, end_samp=None):
+ y, sr = sf.read(path, start=start_samp, stop=end_samp, dtype=np.float32)
+ assert sr == self.sr, f"samplerate of data {sr} does not match requested samplerate {self.sr}"
+ assert y.ndim == 2 and y.shape[-1] > 1, f"audios must be multi-channel but the shape of {path} is {y.shape}"
+
+ if self.channel_idx is not None:
+ return y[..., self.channel_idx]
+ else:
+ return y
+
+ def _random_start_and_end(self, path_dict, min_overlap=None):
+ if min_overlap is None:
+ min_overlap = self.chunk_size // 5
+
+ assert "num_samples" in path_dict and "offset" in path_dict
+ offset = max(path_dict["offset"]) # start of shorter utterance
+ min_length = min(path_dict["num_samples"]) # length of shorter utterance
+ max_length = max(path_dict["num_samples"]) # mixture length
+
+ # when mixture is shorter than the chunk size we want,
+ # we simply use the entire utterance
+ if max_length < self.chunk_size:
+ return 0, max_length
+
+ # else, we randomly choose partially-overlapped segment
+ # where at least "min_overlap"-length overlap exists
+ left = offset
+ right = left + min_length
+ min_overlap = min(min_length // 2 - 1, min_overlap)
+
+ assert left + min_overlap < right - min_overlap, (left, right, min_overlap)
+ if np.random.random() > 0.5:
+ start = np.random.randint(left + min_overlap, right - min_overlap)
+ end = start + self.chunk_size
+ if end > max_length:
+ start -= end - max_length
+ end = max_length
+ else:
+ end = np.random.randint(left + min_overlap, right - min_overlap)
+ start = end - self.chunk_size
+ if start < 0:
+ start = 0
+ end = self.chunk_size
+ return start, end
+
+ def _normalization(self, y_mix, y_srcs):
+ # variance normalization
+ mean = y_mix.mean(keepdims=True)
+ std = y_mix.std(keepdims=True)
+
+ y_mix = (y_mix - mean) / std
+ if isinstance(y_srcs, dict):
+ y_srcs["reverb"] = (y_srcs["reverb"] - mean[None]) / std[None]
+ elif y_srcs is not None:
+ y_srcs = (y_srcs - mean[None]) / std[None]
+
+ return y_mix, y_srcs
+
+ def _stft(self, samples):
+ assert samples.ndim < 4
+ # if there is channel dim
+ if samples.ndim == 3:
+ n_chan, n_src, n_samples = samples.shape
+ samples = samples.reshape(-1, n_samples)
+ reshaped = True
+ else:
+ reshaped = False
+ X = do_stft(samples, **self.stft_conf)
+ if reshaped:
+ X = X.reshape(X.shape[:2] + (n_chan, n_src))
+ return X
+
+ def _rotate_and_stack_channels_stft(self, feat, srcs):
+ """Rotate the channel order and stack.
+ Suppose the input `feat` is (n_frame, n_freq, n_chan) and n_chan==2.
+ This function stacks `feat` with the channel order of [0, 1] and [1, 0]
+ and returns (n_frame, n_freq, n_chan, n_chan), which is used in ERAS.
+
+        Returned `feat` is stacked along the batch dim in the `collate_seq_eras` function.
+ """
+
+ n_chan = len(self.channel_idx) # number of microphone channels
+ assert feat.shape[-1] == n_chan, feat.shape
+
+ feat_stack = []
+ for ref_channel in range(n_chan):
+ feat_tmp = rotate_channels(feat, ref_channel, channel_dim=-1)
+ feat_stack.append(feat_tmp)
+ feat_stack = torch.stack(feat_stack, dim=-1)
+
+ # first initialize the dict to make code cleaner
+ srcs_stack = {}
+ for src in srcs:
+ if isinstance(srcs[src], dict) and src == "y_srcs":
+ srcs_stack[src] = {}
+ for key in srcs[src]:
+ srcs_stack[src][key] = []
+ else:
+ srcs_stack[src] = []
+
+ for ref_channel in range(n_chan):
+ for src in srcs:
+ if isinstance(srcs[src], dict) and src == "y_srcs":
+ assert srcs[src]["reverb"].shape[-2] == n_chan, (
+ src,
+ key,
+ srcs[src]["reverb"].shape,
+ )
+ tmp = rotate_channels(srcs[src]["reverb"], ref_channel, channel_dim=-2)
+ srcs_stack[src]["reverb"].append(tmp)
+
+ # process torch Tensor
+ elif srcs[src] is not None:
+ assert srcs[src].shape[-1] == n_chan, (
+ src,
+ srcs[src].shape,
+ )
+ tmp = rotate_channels(srcs[src], ref_channel, channel_dim=-1)
+ srcs_stack[src].append(tmp)
+
+ # stack along batch dim
+ for src in srcs_stack:
+ if isinstance(srcs[src], dict) and src == "y_srcs":
+ srcs_stack[src]["reverb"] = torch.stack(srcs_stack[src]["reverb"], dim=-1)
+ elif srcs[src] is not None:
+ srcs_stack[src] = torch.stack(srcs_stack[src], dim=-1)
+
+ return feat_stack, srcs_stack
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..2debfe1
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,25 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+name: ras
+channels:
+ - pytorch
+ - conda-forge
+ - nvidia
+dependencies:
+ - python=3.11.5
+ - pytorch=2.1.1
+ - pytorch-cuda=11.8
+ - pip
+ - pip:
+ - pytorch-lightning==2.1.2
+ - loguru
+ - pyyaml
+ - SoundFile==0.10.3.post1
+ - tensorboard
+ - protobuf==3.20.3
+ - fast-bss-eval
+ - pesq==0.0.4
+ - numpy==1.26.4
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..0d05405
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import argparse
+from pathlib import Path
+
+import loguru
+import torch
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.loggers import TensorBoardLogger
+
+from lightning_train import RASTrainingModule
+from utils.config import yaml_to_parser
+
+
+def main(args):
+
+ config_path = args.ckpt_path.parent.parent / "hparams.yaml"
+ hparams = yaml_to_parser(config_path)
+ hparams = hparams.parse_args([])
+
+ seed_everything(0, workers=True)
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = True
+
+ exp_name, name, save_dir = [config_path.parents[i].name for i in range(3)]
+ logger = TensorBoardLogger(save_dir=save_dir, name=name, version=exp_name)
+
+ trainer = Trainer(
+ logger=logger,
+ enable_progress_bar=True,
+ deterministic=True,
+ devices=1,
+ num_nodes=1,
+ )
+ # testing
+ loguru.logger.info("Begin Testing")
+ model = RASTrainingModule.load_from_checkpoint(args.ckpt_path, hparams=hparams, data_path=args.data_path)
+ trainer.test(model)
+ loguru.logger.info("Testing complete")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=Path, required=True)
+ parser.add_argument("--data_path", type=Path, required=True)
+ args, other_options = parser.parse_known_args()
+ main(args)
diff --git a/lightning_train.py b/lightning_train.py
new file mode 100644
index 0000000..60cc4fb
--- /dev/null
+++ b/lightning_train.py
@@ -0,0 +1,204 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import os
+import random
+from argparse import Namespace
+
+import fast_bss_eval
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.utils.data as data
+from loguru import logger
+from pesq import pesq
+from pytorch_lightning.utilities import rank_zero_only
+from torch import optim
+
+from datasets.dataset_creator import dataset_creator
+from loss_functions.ras_loss import RASLoss
+from nets.build_model import build_model
+from utils.audio_utils import istft_4dim
+from utils.collate import collate_seq, collate_seq_eras
+
+
+class RASTrainingModule(pl.LightningModule):
+    """LightningModule wrapping a separation model trained with the RAS loss.
+
+    Responsibilities: model construction, optional pretrained-weight loading,
+    train/val/test steps, manual learning-rate warmup, and dataloader setup.
+    """
+
+    def __init__(self, hparams, data_path):
+        """
+        Args:
+            hparams: argparse.Namespace (or object with `model_name` /
+                `model_conf`) holding the experiment configuration.
+            data_path: root path of the dataset.
+        """
+        super().__init__()
+
+        if not isinstance(hparams, Namespace):
+            # NOTE(review): Namespace() accepts no positional arguments, so
+            # this line would raise TypeError if ever reached — verify branch.
+            hparams = Namespace(hparams.model_name, **hparams.model_conf)
+        self.data_path = data_path
+
+        self.save_hyperparameters(hparams)
+        self.model = build_model(hparams.model_name, hparams.model_conf)
+        self.loss = RASLoss(**hparams.eras_loss_conf)
+
+        self.current_step = 0  # used for learning-rate warmup
+
+    def load_pretrained_weight(self):
+        """Load weights from `hparams.pretrained_model_path`, if set.
+
+        Accepts either a Lightning checkpoint (with a "state_dict" key) or a
+        raw state dict; the "model." prefix is stripped before loading.
+        """
+        if self.hparams.pretrained_model_path is not None:
+            if torch.cuda.is_available():
+                state_dict = torch.load(self.hparams.pretrained_model_path)
+            else:
+                state_dict = torch.load(self.hparams.pretrained_model_path, map_location=torch.device("cpu"))
+            try:
+                state_dict = state_dict["state_dict"]
+            except KeyError:
+                print("No key named state_dict. Directly loading from model.")
+            state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
+            self.model.load_state_dict(state_dict)
+            logger.info("Loaded weights from " + self.hparams.pretrained_model_path)
+
+    def on_batch_end(self):
+        # called manually at the end of training_step (not a hook in recent
+        # Lightning versions) to advance the warmup schedule
+        # learning rate warmup
+        self.warmup_lr()
+
+    @rank_zero_only
+    def _symlink_logger(self):
+        # Keep track of which log file goes with which tensorboard log folder
+        tensorboard_log_dir = self.trainer.logger.log_dir
+        logger.info(f"Tensorboard logs: {tensorboard_log_dir}")
+        if os.path.exists(self.hparams.log_file):
+            _, log_name = os.path.split(self.hparams.log_file)
+            new_log_path = os.path.join(tensorboard_log_dir, log_name)
+
+            # when resuming training, symlink already exists
+            if not os.path.islink(new_log_path):
+                os.symlink(os.path.abspath(self.hparams.log_file), new_log_path)
+
+    def warmup_lr(self):
+        """Linearly ramp the LR up to its initial value over
+        `hparams.warmup_steps` steps; no-op once warmup is done."""
+        # get initial learning rate at step 0
+        if self.current_step == 0:
+            for param_group in self.optimizers().optimizer.param_groups:
+                self.peak_lr = param_group["lr"]
+
+        self.current_step += 1
+        if getattr(self.hparams, "warmup_steps", 0) >= self.current_step:
+            for param_group in self.optimizers().optimizer.param_groups:
+                param_group["lr"] = self.peak_lr * self.current_step / self.hparams.warmup_steps
+
+    def on_train_start(self):
+        # symlink the text log into the tensorboard dir, then load pretrained
+        # weights (if any) before the first optimization step
+        self._symlink_logger()
+        self.load_pretrained_weight()
+
+    def forward(self, x):
+        # delegate to the separation network
+        return self.model(x)
+
+    def _step(self, batch):
+        """Shared train/val logic: run the model and compute the RAS loss.
+
+        Returns:
+            dict of loss terms; "loss" is the one used for backprop.
+        """
+        input_features, target_dict = batch
+        input_features, lens = input_features
+
+        y = self.forward(input_features)  # (batch, frame, freq) -> (batch, frame, freq, num_src)
+
+        loss = self.loss(y, target_dict, device=self.device, training=self.model.training)
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        loss = self._step(batch)
+        # prefix every loss term with "train/" for tensorboard grouping
+        loss_for_logging = {}
+        for k, v in loss.items():
+            loss_for_logging[f"train/{k}"] = v
+        self.log_dict(loss_for_logging, on_step=True, on_epoch=True, sync_dist=True)
+
+        self.on_batch_end()
+        return loss["loss"]
+
+    def validation_step(self, batch, batch_idx):
+        loss = self._step(batch)
+        loss_for_logging = {}
+        for k, v in loss.items():
+            loss_for_logging[f"val/{k}"] = v
+        self.log_dict(loss_for_logging, on_epoch=True, sync_dist=True)
+
+        return loss["loss"]
+
+    def test_step(self, batch, batch_idx):
+        """Separate one test mixture and log SI-SNR / SDR / PESQ metrics."""
+        input_features, target_dict = batch
+        input_features, lens = input_features
+        sample_rate = self.hparams.dataloading_conf["sr"]
+
+        est = self.forward(input_features)  # (batch, frame, freq) -> (batch, frame, freq, num_src)
+
+        # apply FCP
+        est = self.loss.filtering_func(est, input_features)
+        est = est[..., self.loss.ref_channel, :]
+
+        # TF-domain -> time-domain by iSTFT
+        est = istft_4dim(est, **self.loss.stft_conf)[0].T
+
+        # reference signal
+        ref = target_dict["y_srcs"]["reverb"][0][0, ..., self.loss.ref_channel, :].T
+
+        # compute metrics
+        m = min(ref.shape[-1], est.shape[-1])
+        sisnr, perm = fast_bss_eval.si_sdr(ref[..., :m], est[..., :m], return_perm=True)
+        sisnr = sisnr.mean().cpu().numpy()
+        perm = perm.cpu().numpy()
+
+        sdr = fast_bss_eval.sdr(ref, est).mean().cpu().numpy()
+
+        ref, est = ref.cpu().numpy(), est.cpu().numpy()
+        # PESQ averaged over sources, pairing est to ref via the SI-SDR
+        # permutation; narrow-band mode
+        pesq_score = 0.0
+        for i, p in enumerate(perm):
+            pesq_score += pesq(sample_rate, ref[i], est[p], mode="nb")
+        # NOTE(review): uses the loop variable `i` after the loop (== len(perm)
+        # here); raises NameError if perm is empty — confirm perm is never empty
+        pesq_score /= i + 1
+
+        result = {
+            "test/sisnr": float(sisnr),
+            "test/sdr": float(sdr),
+            "test/pesq": float(pesq_score),
+        }
+        self.log_dict(result, on_epoch=True)
+
+        return result
+
+    def configure_optimizers(self):
+        # Adam + ReduceLROnPlateau monitored on the validation loss
+        optimizer = optim.Adam(self.parameters(), **self.hparams.optimizer_conf)
+        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **self.hparams.scheduler_conf)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "val/loss",
+        }
+
+    def _init_fn(self, worker_id):
+        # per-worker seeding so dataloader workers are reproducible
+        random.seed(self.hparams.seed + worker_id)
+        np.random.seed(self.hparams.seed + worker_id)
+        torch.manual_seed(self.hparams.seed + worker_id)
+
+    def _get_data_loader(self, partition):
+        """Build the DataLoader for partition "tr" (train), "cv" (val),
+        or "tt" (test); test always uses batch size 1."""
+        # shuffle=None lets DataLoader fall back to its default (False)
+        shuffle = self.hparams.shuffle if partition == "tr" else None
+        if partition == "tr":
+            batch_size = self.hparams.batch_size
+        elif partition == "cv":
+            batch_size = self.hparams.val_batch_size
+        else:
+            batch_size = 1
+
+        d = dataset_creator(self.hparams, self.data_path, partition)
+
+        # ERAS datasets require their dedicated collate function
+        if getattr(d, "running_eras", False):
+            collate_fn = collate_seq_eras
+        else:
+            collate_fn = collate_seq
+
+        return data.DataLoader(
+            d,
+            batch_size,
+            collate_fn=collate_fn,
+            shuffle=shuffle,
+            num_workers=self.hparams.num_workers,
+            worker_init_fn=self._init_fn,
+        )
+
+    def train_dataloader(self):
+        return self._get_data_loader("tr")
+
+    def val_dataloader(self):
+        return self._get_data_loader("cv")
+
+    def test_dataloader(self):
+        return self._get_data_loader("tt")
diff --git a/loss_functions/__init__.py b/loss_functions/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/loss_functions/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/loss_functions/complex.py b/loss_functions/complex.py
new file mode 100644
index 0000000..b484916
--- /dev/null
+++ b/loss_functions/complex.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+from loss_functions.general import pit_from_pairwise_loss
+
+
+def complex_l1(est, ref, permutation_search=False, reduction="mean_all", **kwargs):
+    """L1 loss on complex stft representations.
+    L1 loss on real part, imaginary part, and magnitude are computed.
+
+    Parameters
+    ----------
+    est: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
+        Estimated (separated) signals.
+    ref: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
+        Reference (ground-truth) signals.
+    permutation_search: bool
+        Whether to do permutation search between `est` and `ref`.
+    reduction: str
+        Argument for controlling the shape of returned tensor.
+        `keep_batchdim` returns a tensor with (n_batch, ),
+        `mean_all` returns a tensor with (),
+        and else, just return a tensor as it is.
+    **kwargs:
+        Ignored; accepted so callers can pass a shared keyword set.
+
+    Returns
+    ----------
+    loss: torch.Tensor,
+        Loss value.
+    """
+    # sometimes ref is a tuple of a tensor and a length
+    if isinstance(ref, tuple):
+        ref = ref[0]
+    # per-source magnitude mass used to normalize the loss scale
+    # (no epsilon: assumes ref is never all-zero — TODO confirm)
+    normalizer = abs(ref).sum(dim=(1, 2))  # (batch, n_src)
+
+    # expand dimension for getting pairwise loss
+    if permutation_search:
+        est = est.unsqueeze(-1)
+        ref = ref.unsqueeze(-2)
+        normalizer = normalizer.unsqueeze(-2)
+
+    # loss
+    real_l1 = abs(ref.real - est.real).sum(dim=(1, 2))
+    imag_l1 = abs(ref.imag - est.imag).sum(dim=(1, 2))
+    abs_l1 = abs(abs(ref) - abs(est)).sum(dim=(1, 2))
+    loss = real_l1 + imag_l1 + abs_l1
+
+    loss = loss / normalizer
+    if permutation_search:
+        # compute loss with brute-force optimal permutation search
+        loss = pit_from_pairwise_loss(loss)
+        # average over sources
+        loss = loss / ref.shape[-1]
+
+    if reduction is None:
+        return loss
+    elif reduction == "keep_batchdim":
+        if loss.ndim == 1:
+            return loss
+        else:
+            return loss.mean(dim=-1)
+    elif reduction == "mean_all":
+        return loss.mean()
+    else:
+        raise RuntimeError(f"Choose proper reduction: {reduction}")
diff --git a/loss_functions/general.py b/loss_functions/general.py
new file mode 100644
index 0000000..4ee9cab
--- /dev/null
+++ b/loss_functions/general.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import itertools
+
+import torch
+
+
+def perms(num_sources):
+    # All permutations of source indices [0, num_sources), as a list of tuples.
+    return list(itertools.permutations(range(num_sources)))
+
+
+def pit_from_pairwise_loss(pw_dist_mat, reduction="sum", return_perm=False):
+    """Permutation-invariant training (PIT) loss from a pairwise loss matrix.
+
+    Parameters
+    ----------
+    pw_dist_mat: torch.Tensor, (n_batch, n_src, n_src)
+        Pairwise loss between each estimate (dim 1) and reference (dim 2).
+    reduction: str
+        "sum" returns the loss summed over sources, shape (n_batch,);
+        anything else keeps the per-source loss, shape (n_batch, 1, n_src)
+        (squeezed to (n_batch, n_src) when `return_perm` is True).
+    return_perm: bool
+        Whether to also return the optimal permutation, (n_batch, n_src).
+
+    Returns
+    ----------
+    Loss under the best permutation (and optionally the permutation itself).
+    """
+    n_batch, n_src, _ = pw_dist_mat.shape
+    perm_mat = pw_dist_mat.new_tensor(perms(n_src), dtype=torch.long)  # [n_perm, n_src]
+    ind = pw_dist_mat.new_tensor(range(n_src), dtype=torch.long).unsqueeze(0)  # [1, n_src]
+    expanded_perm_dist_mat = pw_dist_mat[:, ind, perm_mat]  # [n_batch, n_perm, n_src]
+    perm_dist_mat = torch.sum(expanded_perm_dist_mat, dim=2)  # [n_batch, n_perm]
+    min_loss_perm, min_inds = torch.min(perm_dist_mat, dim=1)  # [n_batch]
+
+    if return_perm:
+        # expand the winning permutation index so we can gather the actual
+        # source ordering out of perm_mat
+        opt_perm, perm_mat = torch.broadcast_tensors(
+            min_inds[:, None, None], perm_mat[None]
+        )  # [n_batch, n_perm, n_src]
+        opt_perm = torch.gather(perm_mat, 1, opt_perm[:, [0]])[:, 0]
+
+    if reduction == "sum":
+        if return_perm:
+            return min_loss_perm, opt_perm
+        else:
+            return min_loss_perm
+    else:
+        # sometimes we want (batch x n_src) loss matrix
+        min_inds = min_inds[:, None, None].tile(1, 1, n_src)
+        min_loss = torch.gather(expanded_perm_dist_mat, 1, min_inds)
+        if return_perm:
+            return min_loss[:, 0, :], opt_perm  # (n_batch, n_src)
+        else:
+            return min_loss
diff --git a/loss_functions/ras_loss.py b/loss_functions/ras_loss.py
new file mode 100644
index 0000000..784fab7
--- /dev/null
+++ b/loss_functions/ras_loss.py
@@ -0,0 +1,295 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+from functools import partial
+from typing import Dict
+
+import torch
+from fast_bss_eval import si_sdr
+
+from loss_functions.complex import complex_l1
+from utils.audio_utils import istft_4dim, stft_3dim
+from utils.forward_convolutive_prediction import forward_convolutive_prediction
+
+loss_funcs = {"complex_l1": complex_l1}
+
+
+class RASLoss(object):
+ """
+ Reverberation as Supervision (RAS) loss [1, 2].
+ This class also supports over-determined conditions (UNSSOR) [3].
+
+
+ Parameters
+ ----------
+ loss_func: str
+ Name of loss function. Only "complex_l1" is supported now.
+ future_taps: int
+ The number of future taps in FCP.
+ past_taps: int
+ The number of past taps in FCP.
+ ref_channel_loss_weight: float
+ Loss weight on the reference channel (input channel of separation model).
+ In ERAS, 0.0 is recommended.
+ isms_loss_weight: float
+ Loss weight of intra-source magnitude scattering (ISMS) loss.
+ icc_loss_weight: float
+ Loss weight of inter-channel consistency (ICC) loss.
+ ref_channel: int
+ Index of reference channel.
+ nonref_channel: int
+ Index of non-reference channel. Must be different from `ref_channel`.
+ unsupervised: bool
+ Either of unsupervised or supervised.
+ supervised_loss_type: str
+ How to compute supervised loss when doing (semi-)supervised learning.
+ before_filtering, after_filtering_ref_channel, or after_filtering_nonref_channel.
+ Supervised loss is computed when doing unsupervised learning on the validation set
+ to monitor the performance.
+ stft_conf: dict
+ Dictionary containing STFT parameters.
+
+ References
+ ----------
+ [1]: Rohith Aralikatti, Christoph Boeddeker, Gordon Wichern, Aswin Subramanian, and Jonathan Le Roux,
+ "Reverberation as Supervision for Speech Separation," Proc. ICASSP, 2023.
+
+ [2]: Kohei Saijo, Gordon Wichern, François G. Germain, Zexu Pan, and Jonathan Le Roux,
+ "Enhanced Reverberation as Supervision for Unsupervised Speech Separation," Proc. Interspeech, 2024.
+
+ [3]: Zhong-Qiu Wang and Shinji Watanabe, "UNSSOR: unsupervised neural speech separation
+ by leveraging over-determined training mixtures," Proc. NeurIPS, 2023.
+ """
+
+ def __init__(
+ self,
+ loss_func: str = "complex_l1",
+ future_taps: int = 1,
+ past_taps: int = 19,
+ ref_channel_loss_weight: float = 0.0,
+ isms_loss_weight: float = 0.0,
+ icc_loss_weight: float = 0.0,
+ ref_channel: int = 0,
+ nonref_channel: int = 1,
+ unsupervised: bool = True,
+ supervised_loss_type: str = "after_filtering_ref_channel",
+ stft_conf: Dict = None,
+ ):
+ assert loss_func in loss_funcs, loss_func
+
+ self.loss_func = loss_funcs[loss_func]
+ self.ref_channel = ref_channel
+ self.nonref_channel = nonref_channel
+ self.unsupervised = unsupervised
+ self.supervised_loss_type = supervised_loss_type
+ self.stft_conf = stft_conf
+
+ # loss weights
+ self.ref_channel_loss_weight = ref_channel_loss_weight
+ self.isms_loss_weight = isms_loss_weight
+ self.icc_loss_weight = icc_loss_weight
+
+ if self.icc_loss_weight > 0:
+ self.icc_loss = partial(inter_channel_consistency_loss, loss_func=self.loss_func)
+
+ # define FCP used in the forward path
+ self.filtering_func = partial(
+ forward_convolutive_prediction,
+ past_taps=past_taps,
+ future_taps=future_taps,
+ )
+
+ def __call__(self, nn_outputs: torch.Tensor, targets: Dict, **kwargs):
+ # mixure signal used as supervision in RAS or UNSSOR
+ mix = targets["y_mix_stft"][0]
+
+ n_batch, n_channels = mix.shape[0], mix.shape[-1]
+ unsup_loss = 0.0
+
+ est_src_filtered = self.filtering_func(nn_outputs, mix) # (n_batch, n_frames, n_freqs, n_chan, n_src)
+ est_mix = est_src_filtered.sum(dim=-1)
+
+ # compute RAS loss
+ ras_loss = self.loss_func(
+ est_mix,
+ mix,
+ permutation_search=False,
+ reduction=None,
+ ) # loss: (n_batch, n_chan)
+
+ assert ras_loss.shape == (
+ n_batch,
+ n_channels,
+ ), f"loss must be (batch x channel) but {ras_loss.shape}"
+
+ # compute ISMS loss
+ isms_loss = intra_source_magnitude_scattering_loss(est_src_filtered, mix, reduction=None)
+ assert ras_loss.shape == isms_loss.shape, "(loss must be (batch x channel)"
+ f"but {ras_loss.shape} and {isms_loss.shape})"
+ unsup_loss = ras_loss + self.isms_loss_weight * isms_loss
+
+ # weighting loss on reference channel
+ unsup_loss[:, self.ref_channel] *= self.ref_channel_loss_weight
+ unsup_loss = unsup_loss.sum(dim=-1)
+
+ # inter-source consistency loss
+ training = unsup_loss.requires_grad
+
+ # Currently only one of L or R is loaded during validation
+ # and we cannot compute inter-source loss on dev set
+ if self.icc_loss_weight > 0 and training:
+ icc_loss = self.icc_loss(est_src_filtered)
+ unsup_loss += self.icc_loss_weight * icc_loss
+
+ # in supervised case or on validation set, we compute the supervised loss
+ if training and self.unsupervised:
+ loss = unsup_loss.mean()
+ sup_loss = torch.zeros_like(unsup_loss) # for logging
+ else:
+ # pick some channels for SI-SDR evaluation
+ est_src_ref_channel = est_src_filtered[..., self.ref_channel, :]
+ est_src_nonref_channel = est_src_filtered[..., self.nonref_channel, :]
+
+ targets = targets["y_srcs"]
+
+ if self.supervised_loss_type == "before_filtering":
+ est = nn_outputs
+ tgt = targets["reverb"][0][..., self.ref_channel, :]
+ elif self.supervised_loss_type == "after_filtering_ref_channel":
+ est = est_src_ref_channel
+ tgt = targets["reverb"][0][..., self.ref_channel, :]
+ elif self.supervised_loss_type == "after_filtering_nonref_channel":
+ est = est_src_nonref_channel
+ tgt = targets["reverb"][0][..., self.nonref_channel, :]
+
+ tgt = stft_3dim(tgt, **self.stft_conf)
+
+ sup_loss = self.loss_func(
+ est[:, : tgt.shape[1]],
+ tgt,
+ permutation_search=True,
+ reduction="keep_batchdim",
+ )
+
+ uloss = unsup_loss.mean()
+ sloss = sup_loss.mean()
+ loss = uloss if self.unsupervised else sloss
+
+ # metrics for logging and backprop
+ metrics = {
+ "loss": loss,
+ "ras_loss": ras_loss.mean(),
+ "isms_loss": isms_loss.mean(),
+ "unsup_loss": unsup_loss.mean(),
+ "sup_loss": sup_loss.mean(),
+ }
+
+ # compute SI-SDR for logging
+ if not training:
+ # istft to get time-domain signals when using tf-domain loss
+ est_src_ref_channel = istft_4dim(est_src_ref_channel, **self.stft_conf).transpose(-1, -2)
+ est_src_nonref_channel = istft_4dim(est_src_nonref_channel, **self.stft_conf).transpose(-1, -2)
+
+ ref = targets["reverb"][0].transpose(-1, 1)
+ m = min(est_src_ref_channel.shape[-1], ref.shape[-1])
+ metrics["sisnr_ref_channel"] = si_sdr(ref[..., self.ref_channel, :m], est_src_ref_channel[..., :m]).mean()
+ metrics["sisnr_nonref_channel"] = si_sdr(
+ ref[..., self.nonref_channel, :m], est_src_nonref_channel[..., :m]
+ ).mean()
+
+ return metrics
+
+
+def intra_source_magnitude_scattering_loss(est, mix, reduction=None, eps=1e-8):
+    """Intra source magnitude scattering loss (ISMS) proposed in [3] (Eq.10).
+
+    Parameters
+    ----------
+    est: torch.Tensor, (..., n_frame, n_freq, n_chan, n_src)
+        Separation estimates AFTER applying FCP
+    mix: torch.Tensor, (..., n_frame, n_freq, n_chan)
+        Mixture observed at the microphone
+    reduction: str
+        How to aggregate the loss values.
+        Must be chosen from {None, "keep_batchdim", "mean_all"}
+    eps: float
+        Small constant to avoid log(0).
+
+    Returns
+    ----------
+    loss: torch.Tensor
+        ISMS loss value. Shape depends on the specified reduction.
+    """
+    mix_logmag = torch.log(abs(mix) + eps)
+    est_logmag = torch.log(abs(est) + eps)
+
+    # log-magnitude variance over frequency, then summed over frames;
+    # est additionally averages over sources -> both end up (batch, n_chan)
+    mix_var = torch.var(mix_logmag, dim=-2).sum(dim=-2)
+    est_var = torch.var(est_logmag, dim=-3).mean(dim=-1).sum(dim=-2)
+
+    loss = est_var / mix_var  # (batch, n_chan)
+
+    if reduction is None:
+        return loss
+    elif reduction == "keep_batchdim":
+        return loss.mean(dim=-1)
+    elif reduction == "mean_all":
+        return loss.mean()
+    else:
+        raise RuntimeError(f"Choose proper reduction: {reduction}")
+
+
+def inter_channel_consistency_loss(est: torch.Tensor, loss_func: callable):
+    """Inter-channel consistency (ICC) loss proposed in [2].
+
+    In ERAS, each component in the mini-batch `est` is:
+        est[0]: separated signals from mix1 at 1st channel and mapped to [1st, 2nd] channels
+        est[1]: separated signals from mix1 at 2nd channel and mapped to [2nd, 1st] channels
+        est[2]: separated signals from mix2 at 1st channel and mapped to [1st, 2nd] channels
+        est[3]: separated signals from mix2 at 2nd channel and mapped to [2nd, 1st] channels
+        ...
+        est[B-2]: separated signals from mix{B/2} at 1st channel and mapped to [1st, 2nd] channels
+        est[B-1]: separated signals from mix{B/2} at 2nd channel and mapped to [2nd, 1st] channels
+
+    Note that mix{1,...,B/2} are different (totally unrelated) mixtures in the same mini-batch.
+
+    The ICC loss uses signals mapped to the reference microphone as the pseudo-targets of
+    those mapped to the non-reference microphone, e.g.,
+    - est[0][0] is the pseudo-target of est[1][1]
+    - est[1][0] is the pseudo-target of est[0][1]
+
+
+    Parameters
+    ----------
+    est: torch.Tensor
+        Separation outputs after FCP (n_batch, n_frames, n_freqs, n_chan, n_src)
+    loss_func: callable
+        Loss function
+
+    Returns
+    ----------
+    loss: torch.Tensor,
+        The inter-channel consistency loss value.
+    """
+    n_batch, n_chan = est.shape[0], est.shape[-2]
+    # the pairing scheme below assumes exactly two channels
+    assert n_chan == 2
+
+    # the first channel in `n_chan` dimension is the one mapped to the reference microphone
+    # while the second is mapped to the non-reference microphone.
+    # pseudo-targets are detached so gradients only flow through `est`
+    ref = est[..., 0, :].clone().detach()  # signals mapped to reference microphone
+    est = est[..., 1, :]  # signals mapped to non-reference microphone
+
+    # change the batch order of `est` to [1, 0, 3, 2, ...]
+    # to align the batch order of ref and est
+    p_tmp = torch.arange(n_batch)
+    p = p_tmp.clone()
+    p[0::2], p[1::2] = p_tmp[1::2], p_tmp[0::2]  # p: [1, 0, 3, 2, ...]
+    est = est[p]
+
+    # compute loss
+    loss = loss_func(
+        est,
+        ref,
+        permutation_search=True,
+        reduction=None,
+    )
+    return loss.mean()
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/nets/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/nets/build_model.py b/nets/build_model.py
new file mode 100644
index 0000000..36f7c71
--- /dev/null
+++ b/nets/build_model.py
@@ -0,0 +1,15 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+from .tfgridnetv2 import TFGridNetV2
+
+
+def build_model(model_name, model_conf):
+ if model_name == "tfgridnetv2":
+ model = TFGridNetV2(**model_conf)
+ else:
+ raise ValueError("Model type {} not currently supported.".format(model_name))
+
+ return model
diff --git a/nets/tfgridnetv2.py b/nets/tfgridnetv2.py
new file mode 100644
index 0000000..576ba27
--- /dev/null
+++ b/nets/tfgridnetv2.py
@@ -0,0 +1,351 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+# Copyright (C) 2023 ESPnet Developers
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.parameter import Parameter
+
+"""This script is adapted from ESPnet (https://github.com/espnet/espnet).
+Part of the code is modified for our use.
+https://github.com/espnet/espnet/blob/master/espnet2/enh/separator/tfgridnetv2_separator.py
+"""
+
+
+class TFGridNetV2(nn.Module):
+    """Offline TFGridNetV2. Compared with TFGridNet, TFGridNetV2 speeds up the code
+    by vectorizing multiple heads in self-attention, and better dealing with
+    Deconv1D in each intra- and inter-block when emb_ks == emb_hs.
+
+    Reference:
+    [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
+    "TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation",
+    in TASLP, 2023.
+    [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
+    "TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural
+    Speaker Separation", in ICASSP, 2023.
+
+    Args:
+        fft_size: fft size.
+        n_imics: number of microphone channels (only fixed-array geometry supported).
+        n_srcs: number of output sources/speakers.
+        n_layers: number of TFGridNetV2 blocks.
+        lstm_hidden_units: number of hidden units in LSTM.
+        attn_n_head: number of heads in self-attention
+        attn_approx_qk_dim: approximate dimension of frame-level key and value tensors
+        emb_dim: embedding dimension
+        emb_ks: kernel size for unfolding and deconv1D
+        emb_hs: hop size for unfolding and deconv1D
+        eps: small epsilon for normalization layers.
+    """
+
+    def __init__(
+        self,
+        fft_size=256,
+        n_imics=1,
+        n_srcs=2,
+        n_layers=4,
+        lstm_hidden_units=192,
+        attn_n_head=4,
+        attn_approx_qk_dim=512,
+        emb_dim=48,
+        emb_ks=4,
+        emb_hs=1,
+        eps=1.0e-5,
+    ):
+        super().__init__()
+        self.n_srcs = n_srcs
+        self.n_layers = n_layers
+        self.n_imics = n_imics
+
+        # number of STFT frequency bins
+        n_freqs = fft_size // 2 + 1
+
+        # input conv maps stacked real/imag channels to the embedding dim
+        t_ksize = 3
+        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
+        self.conv = nn.Sequential(
+            nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding),
+            nn.GroupNorm(1, emb_dim, eps=eps),
+        )
+
+        self.blocks = nn.ModuleList([])
+        for _ in range(n_layers):
+            self.blocks.append(
+                GridNetV2Block(
+                    emb_dim,
+                    emb_ks,
+                    emb_hs,
+                    n_freqs,
+                    lstm_hidden_units,
+                    n_head=attn_n_head,
+                    approx_qk_dim=attn_approx_qk_dim,
+                    eps=eps,
+                )
+            )
+
+        # output deconv produces 2 (real/imag) maps per source
+        self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, ks, padding=padding)
+
+    def forward(
+        self,
+        input: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward.
+
+        Args:
+            input (torch.Tensor): batched multi-channel audio tensor with
+                    M audio channels in TF-domain [B, T, F, M]
+
+        Returns:
+            batch (torch.Tensor): batched monaural audio tensor with
+                    N separated signals in TF-domain [B, T, F, N]
+        """
+
+        # using only specified number of channels
+        batch0 = input[..., : self.n_imics]
+
+        batch = torch.movedim(batch0, 3, 1)  # [B, M, T, F]
+        # complex input is split into stacked real/imag feature maps
+        batch = torch.cat((batch.real, batch.imag), dim=1)  # [B, 2*M, T, F]
+        n_batch, _, n_frames, n_freqs = batch.shape
+
+        batch = self.conv(batch)  # [B, -1, T, F]
+
+        for ii in range(self.n_layers):
+            batch = self.blocks[ii](batch)  # [B, -1, T, F]
+
+        batch = self.deconv(batch)  # [B, n_srcs*2, T, F]
+
+        # reassemble real/imag planes into complex source estimates
+        batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])
+        batch = torch.complex(batch[:, :, 0], batch[:, :, 1])
+        return torch.movedim(batch, 1, 3)
+
+    @property
+    def num_spk(self):
+        # alias used by callers expecting a "number of speakers" attribute
+        return self.n_srcs
+
+    @staticmethod
+    def pad2(input_tensor, target_len):
+        # zero-pad the last dimension up to target_len
+        input_tensor = torch.nn.functional.pad(input_tensor, (0, target_len - input_tensor.shape[-1]))
+        return input_tensor
+
+
+class GridNetV2Block(nn.Module):
+    """One TF-GridNetV2 block: intra-frequency BLSTM, inter-frame BLSTM, and
+    frame-level multi-head self-attention, each with a residual connection."""
+
+    def __getitem__(self, key):
+        # allows dict-style access (self["attn_conv_Q"]) to named submodules
+        return getattr(self, key)
+
+    def __init__(
+        self,
+        emb_dim,
+        emb_ks,
+        emb_hs,
+        n_freqs,
+        hidden_channels,
+        n_head=4,
+        approx_qk_dim=512,
+        eps=1e-5,
+    ):
+        super().__init__()
+
+        in_channels = emb_dim * emb_ks
+
+        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
+        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
+        # Linear suffices when unfold patches do not overlap (emb_ks == emb_hs)
+        if emb_ks == emb_hs:
+            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
+        else:
+            self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
+
+        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
+        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
+        if emb_ks == emb_hs:
+            self.inter_linear = nn.Linear(hidden_channels * 2, in_channels)
+        else:
+            self.inter_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
+
+        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)  # approx_qk_dim is only approximate
+        assert emb_dim % n_head == 0
+
+        self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1))
+        self.add_module(
+            "attn_norm_Q",
+            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
+        )
+
+        self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1))
+        self.add_module(
+            "attn_norm_K",
+            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
+        )
+
+        self.add_module("attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1))
+        self.add_module(
+            "attn_norm_V",
+            AllHeadPReLULayerNormalization4DCF((n_head, emb_dim // n_head, n_freqs), eps=eps),
+        )
+
+        self.add_module(
+            "attn_concat_proj",
+            nn.Sequential(
+                nn.Conv2d(emb_dim, emb_dim, 1),
+                nn.PReLU(),
+                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
+            ),
+        )
+
+        self.emb_dim = emb_dim
+        self.emb_ks = emb_ks
+        self.emb_hs = emb_hs
+        self.n_head = n_head
+
+    def forward(self, x):
+        """GridNetV2Block Forward.
+
+        Args:
+            x: [B, C, T, Q]
+            out: [B, C, T, Q]
+        """
+        B, C, old_T, old_Q = x.shape
+
+        # pad T and Q so the unfold windows tile the padded signal exactly
+        olp = self.emb_ks - self.emb_hs
+        T = math.ceil((old_T + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
+        Q = math.ceil((old_Q + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
+
+        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
+        x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]
+
+        # intra RNN
+        input_ = x
+        intra_rnn = self.intra_norm(input_)  # [B, T, Q, C]
+        if self.emb_ks == self.emb_hs:
+            # non-overlapping patches: a plain reshape replaces unfold/deconv
+            intra_rnn = intra_rnn.view([B * T, -1, self.emb_ks * C])  # [BT, Q//I, I*C]
+            intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, Q//I, H]
+            intra_rnn = self.intra_linear(intra_rnn)  # [BT, Q//I, I*C]
+            intra_rnn = intra_rnn.view([B, T, Q, C])
+        else:
+            intra_rnn = intra_rnn.view([B * T, Q, C])  # [BT, Q, C]
+            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, C, Q]
+            intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BT, C*I, -1]
+            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*I]
+
+            intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]
+
+            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
+            intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
+            intra_rnn = intra_rnn.view([B, T, C, Q])
+            intra_rnn = intra_rnn.transpose(-2, -1)  # [B, T, Q, C]
+        intra_rnn = intra_rnn + input_  # [B, T, Q, C]
+
+        intra_rnn = intra_rnn.transpose(1, 2)  # [B, Q, T, C]
+
+        # inter RNN
+        input_ = intra_rnn
+        inter_rnn = self.inter_norm(input_)  # [B, Q, T, C]
+        if self.emb_ks == self.emb_hs:
+            inter_rnn = inter_rnn.view([B * Q, -1, self.emb_ks * C])  # [BQ, T//I, I*C]
+            inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, T//I, H]
+            inter_rnn = self.inter_linear(inter_rnn)  # [BQ, T//I, I*C]
+            inter_rnn = inter_rnn.view([B, Q, T, C])
+        else:
+            inter_rnn = inter_rnn.view(B * Q, T, C)  # [BQ, T, C]
+            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, C, T]
+            inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BQ, C*I, -1]
+            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C*I]
+
+            inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, -1, H]
+
+            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, H, -1]
+            inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]
+            inter_rnn = inter_rnn.view([B, Q, C, T])
+            inter_rnn = inter_rnn.transpose(-2, -1)  # [B, Q, T, C]
+        inter_rnn = inter_rnn + input_  # [B, Q, T, C]
+
+        inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # [B, C, T, Q]
+
+        # crop the padding back off before attention
+        inter_rnn = inter_rnn[..., olp : olp + old_T, olp : olp + old_Q]
+        batch = inter_rnn
+
+        # frame-level multi-head self-attention over the time axis
+        Q = self["attn_norm_Q"](self["attn_conv_Q"](batch))  # [B, n_head, C, T, Q]
+        K = self["attn_norm_K"](self["attn_conv_K"](batch))  # [B, n_head, C, T, Q]
+        V = self["attn_norm_V"](self["attn_conv_V"](batch))  # [B, n_head, C, T, Q]
+        # fold heads into the batch dimension so all heads run in one matmul
+        Q = Q.view(-1, *Q.shape[2:])  # [B*n_head, C, T, Q]
+        K = K.view(-1, *K.shape[2:])  # [B*n_head, C, T, Q]
+        V = V.view(-1, *V.shape[2:])  # [B*n_head, C, T, Q]
+
+        Q = Q.transpose(1, 2)
+        Q = Q.flatten(start_dim=2)  # [B', T, C*Q]
+
+        K = K.transpose(2, 3)
+        K = K.contiguous().view([B * self.n_head, -1, old_T])  # [B', C*Q, T]
+
+        V = V.transpose(1, 2)  # [B', T, C, Q]
+        old_shape = V.shape
+        V = V.flatten(start_dim=2)  # [B', T, C*Q]
+        emb_dim = Q.shape[-1]
+
+        # scaled dot-product attention over frames
+        attn_mat = torch.matmul(Q, K) / (emb_dim**0.5)  # [B', T, T]
+        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]
+        V = torch.matmul(attn_mat, V)  # [B', T, C*Q]
+
+        V = V.reshape(old_shape)  # [B', T, C, Q]
+        V = V.transpose(1, 2)  # [B', C, T, Q]
+        emb_dim = V.shape[1]
+
+        # un-fold heads and project the concatenated heads back to emb_dim
+        batch = V.contiguous().view([B, self.n_head * emb_dim, old_T, old_Q])  # [B, C, T, Q])
+        batch = self["attn_concat_proj"](batch)  # [B, C, T, Q])
+
+        # residual connection around the attention path
+        out = batch + inter_rnn
+        return out
+
+
+class LayerNormalization4DCF(nn.Module):
+    """Layer normalization over the channel and frequency dims of a
+    [B, C, T, F] tensor, with learnable per-(C, F) affine parameters."""
+
+    def __init__(self, input_dimension, eps=1e-5):
+        super().__init__()
+        # input_dimension = (C, F)
+        assert len(input_dimension) == 2
+        param_size = [1, input_dimension[0], 1, input_dimension[1]]
+        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        init.ones_(self.gamma)
+        init.zeros_(self.beta)
+        self.eps = eps
+
+    def forward(self, x):
+        if x.ndim == 4:
+            # normalize jointly over channel (dim 1) and frequency (dim 3)
+            stat_dim = (1, 3)
+        else:
+            raise ValueError("Expected x to have 4 dimensions, but got {}".format(x.ndim))
+        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]
+        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,1,T,1]
+        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
+        return x_hat
+
+
+class AllHeadPReLULayerNormalization4DCF(nn.Module):
+    """PReLU followed by per-head layer normalization.
+
+    Reshapes a [B, H*E, T, F] tensor to [B, H, E, T, F], applies a per-head
+    PReLU, then normalizes over (E, F) with learnable affine parameters.
+    """
+
+    def __init__(self, input_dimension, eps=1e-5):
+        super().__init__()
+        # input_dimension = (n_head, E, n_freqs)
+        assert len(input_dimension) == 3
+        H, E, n_freqs = input_dimension
+        param_size = [1, H, E, 1, n_freqs]
+        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
+        init.ones_(self.gamma)
+        init.zeros_(self.beta)
+        self.act = nn.PReLU(num_parameters=H, init=0.25)
+        self.eps = eps
+        self.H = H
+        self.E = E
+        self.n_freqs = n_freqs
+
+    def forward(self, x):
+        assert x.ndim == 4
+        B, _, T, _ = x.shape
+        # split the head dimension out of the channel dimension
+        x = x.view([B, self.H, self.E, T, self.n_freqs])
+        x = self.act(x)  # [B,H,E,T,F]
+        # normalize over embedding and frequency, per head and per frame
+        stat_dim = (2, 4)
+        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,H,1,T,1]
+        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,H,1,T,1]
+        x = ((x - mu_) / std_) * self.gamma + self.beta  # [B,H,E,T,F]
+        return x
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..c8b4a2c
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,8 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+pre-commit
+black>=22
+flake8
+pytest
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/tests/test_fcp.py b/tests/test_fcp.py
new file mode 100644
index 0000000..6c5e157
--- /dev/null
+++ b/tests/test_fcp.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import pytest
+import torch
+
+from utils.forward_convolutive_prediction import forward_convolutive_prediction as fcp
+from utils.forward_convolutive_prediction import stack_past_and_future_taps
+
+
@pytest.mark.parametrize("past_tap", [1, 5, 19])
@pytest.mark.parametrize("future_tap", [0, 1])
def test_stack_past_and_future_taps_forward(past_tap, future_tap):
    # Stacking taps inserts a new axis of past + current + future frames.
    shape = (1, 50, 65, 2)  # (n_batch, n_frame, n_freq, n_src)
    spec = torch.randn(shape, dtype=torch.complex64)

    stacked = stack_past_and_future_taps(spec, past_tap, future_tap)

    n_taps = past_tap + future_tap + 1
    assert stacked.shape == (shape[0], shape[1], n_taps, shape[2], shape[3])
+
+
@pytest.mark.parametrize("past_tap", [1, 5, 19])
@pytest.mark.parametrize("future_tap", [0, 1])
@pytest.mark.parametrize("n_chan", [1, 2, 6])
@pytest.mark.parametrize("n_src", [1, 2])
def test_fcp_forward(past_tap, future_tap, n_chan, n_src):
    # FCP should emit one filtered estimate per (channel, source) pair.
    n_batch, n_frame, n_freq = 1, 50, 65
    est = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64)
    mix = torch.randn((n_batch, n_frame, n_freq, n_chan), dtype=torch.complex64)

    filtered = fcp(est, mix, past_tap, future_tap)

    assert filtered.shape == (n_batch, n_frame, n_freq, n_chan, n_src)
diff --git a/tests/test_loss.py b/tests/test_loss.py
new file mode 100644
index 0000000..36db6a3
--- /dev/null
+++ b/tests/test_loss.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import pytest
+import torch
+
+from loss_functions.complex import complex_l1
+from loss_functions.ras_loss import inter_channel_consistency_loss as icc_loss
+from loss_functions.ras_loss import intra_source_magnitude_scattering_loss as isms_loss
+
+
@pytest.mark.parametrize("permutation_search", [True, False])
@pytest.mark.parametrize("reduction", ["mean_all", "keep_batchdim", None])
def test_complex_l1(permutation_search, reduction):
    # The loss shape depends on the reduction mode and on whether the
    # permutation search collapses the source axis.
    n_batch, n_frame, n_freq, n_src = 1, 50, 65, 2
    est = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64)
    ref = torch.randn((n_batch, n_frame, n_freq, n_src), dtype=torch.complex64)

    loss = complex_l1(est, ref, permutation_search=permutation_search, reduction=reduction)

    if reduction is None:
        expected = (n_batch,) if permutation_search else (n_batch, n_src)
        assert loss.shape == expected
    elif reduction == "keep_batchdim":
        assert loss.shape == (n_batch,)
    else:
        assert loss.shape == torch.Size([])
+
+
@pytest.mark.parametrize("reduction", ["mean_all", "keep_batchdim", None])
def test_isms_loss(reduction):
    # ISMS loss keeps (batch, channel), (batch,), or a scalar per reduction mode.
    n_batch, n_frame, n_freq, n_chan, n_src = 1, 50, 65, 2, 2
    est = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64)
    mix = torch.randn((n_batch, n_frame, n_freq, n_chan), dtype=torch.complex64)

    loss = isms_loss(est, mix, reduction=reduction)

    expected_shape = {
        None: (n_batch, n_chan),
        "keep_batchdim": (n_batch,),
        "mean_all": torch.Size([]),
    }[reduction]
    assert loss.shape == expected_shape
+
+
def test_icc_loss():
    # Build a second batch item whose channels are swapped source-wise; the
    # inter-channel consistency loss must then be exactly zero.
    n_batch, n_frame, n_freq, n_chan, n_src = 1, 50, 65, 2, 2
    first = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64)
    swapped = torch.stack((first[..., 1, :], first[..., 0, :]), dim=-2)
    est = torch.cat((first, swapped), dim=0)

    loss = icc_loss(est, complex_l1)

    assert loss.shape == torch.Size([])
    assert loss == 0.0
+
+
@pytest.mark.parametrize("n_chan", [1, 3, 6])
def test_icc_loss_invalid_type(n_chan):
    """The ICC loss supports exactly two channels; anything else must assert."""
    n_batch, n_frame, n_freq, n_src = 2, 50, 65, 2
    est = torch.randn((n_batch, n_frame, n_freq, n_chan, n_src), dtype=torch.complex64)

    with pytest.raises(AssertionError):
        icc_loss(est, complex_l1)
diff --git a/tests/test_tfgridnet.py b/tests/test_tfgridnet.py
new file mode 100644
index 0000000..9fb1c95
--- /dev/null
+++ b/tests/test_tfgridnet.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import pytest
+import torch
+
+from nets.tfgridnetv2 import TFGridNetV2
+
+
@pytest.mark.parametrize("fft_size", [128, 256])
@pytest.mark.parametrize("n_srcs", [1, 2])
@pytest.mark.parametrize("n_imics", [1, 3, 6])
@pytest.mark.parametrize("n_layers", [1, 4, 6])
@pytest.mark.parametrize("lstm_hidden_units", [16])
@pytest.mark.parametrize("attn_n_head", [1, 4])
@pytest.mark.parametrize("attn_approx_qk_dim", [32])
@pytest.mark.parametrize("emb_dim", [16])
@pytest.mark.parametrize("emb_ks", [4])
@pytest.mark.parametrize("emb_hs", [1, 4])
@pytest.mark.parametrize("eps", [1.0e-5])
def test_tfgridnetv2_forward_backward(
    fft_size,
    n_srcs,
    n_imics,
    n_layers,
    lstm_hidden_units,
    attn_n_head,
    attn_approx_qk_dim,
    emb_dim,
    emb_ks,
    emb_hs,
    eps,
):
    """Smoke-test the forward and backward passes of TFGridNetV2."""
    config = dict(
        fft_size=fft_size,
        n_imics=n_imics,
        n_srcs=n_srcs,
        n_layers=n_layers,
        lstm_hidden_units=lstm_hidden_units,
        attn_n_head=attn_n_head,
        attn_approx_qk_dim=attn_approx_qk_dim,
        emb_dim=emb_dim,
        emb_ks=emb_ks,
        emb_hs=emb_hs,
        eps=eps,
    )
    model = TFGridNetV2(**config)
    model.train()

    # Complex multi-channel STFT input of shape (batch, frames, freqs, mics).
    n_freqs = fft_size // 2 + 1
    n_batch, n_frames = 2, 18
    x = torch.complex(
        torch.rand(n_batch, n_frames, n_freqs, n_imics),
        torch.rand(n_batch, n_frames, n_freqs, n_imics),
    )

    separated = model(x)
    assert separated.shape == (n_batch, n_frames, n_freqs, n_srcs)
    # Backprop through a scalar reduction of the output.
    sum(separated).abs().mean().backward()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5f75d73
--- /dev/null
+++ b/train.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import argparse
+from pathlib import Path
+
+import loguru
+import torch
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.loggers import TensorBoardLogger
+
+from lightning_train import RASTrainingModule
+from utils.config import yaml_to_parser
+
+
def main(args):
    """Entry point for training: build config, model, and trainer, then fit.

    :param args: argparse namespace with ``config`` (Path to a YAML
        hyperparameter file) and ``data_path`` (Path to the dataset root).
    """

    hparams = yaml_to_parser(args.config)
    hparams = hparams.parse_args([])
    exp_name = args.config.stem

    seed_everything(hparams.seed, workers=True)

    # some cuda configs
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = True

    logger = TensorBoardLogger(save_dir="exp", name="eras", version=exp_name)
    ckpt_dir = Path(logger.log_dir) / "checkpoints"

    model = RASTrainingModule(hparams, args.data_path)

    if (ckpt_dir / "last.ckpt").exists():
        # resume training from the latest checkpoint
        ckpt_path = ckpt_dir / "last.ckpt"
        skip_first_validation_loop = True
        loguru.logger.info(f"Resume training from {str(ckpt_path)}")
    elif getattr(hparams, "pretrained_model_path", None) is not None:
        # starting from a pretrained model: nothing to resume, and the initial
        # validation loop is skipped
        ckpt_path = None
        skip_first_validation_loop = True
    else:
        # use loguru for consistency with the other branches (was: print)
        loguru.logger.info("Train from scratch")
        ckpt_path = None
        skip_first_validation_loop = False

    ckpt_callback = ModelCheckpoint(**hparams.model_checkpoint)
    callbacks = [LearningRateMonitor(logging_interval="epoch"), ckpt_callback]
    if hparams.early_stopping is not None:
        callbacks.append(EarlyStopping(**hparams.early_stopping))

    trainer = Trainer(
        logger=logger,
        callbacks=callbacks,
        enable_progress_bar=False,
        deterministic=True,
        devices=-1,
        strategy="ddp",
        **hparams.trainer_conf,
    )

    # validation epoch before training for debugging
    if skip_first_validation_loop:
        loguru.logger.info("Skip validating before train when resuming training")
    else:
        loguru.logger.info("Validating before train")
        trainer.validate(model)
        # only log completion when the validation loop actually ran
        # (was logged unconditionally, which was misleading when skipped)
        loguru.logger.info("Finished initial validation")

    # training
    trainer.fit(model, ckpt_path=ckpt_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=Path, required=True)
+ parser.add_argument("--data_path", type=Path, required=True)
+ args, other_options = parser.parse_known_args()
+ main(args)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..8fccc03
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/utils/audio_utils.py b/utils/audio_utils.py
new file mode 100755
index 0000000..b932d15
--- /dev/null
+++ b/utils/audio_utils.py
@@ -0,0 +1,231 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import torch
+import torch.nn.functional as F
+from scipy import signal
+
+
def rotate_channels(input, ref_channel, channel_dim=-1):
    """
    Rotate channel dimension so that the given ref_channel comes first
    """
    # Move channels to the last axis, roll them, then restore the layout.
    swapped = input.transpose(-1, channel_dim)
    rolled = torch.cat((swapped[..., ref_channel:], swapped[..., :ref_channel]), dim=-1)
    rolled = rolled.transpose(-1, channel_dim)
    assert input.shape == rolled.shape, (input.shape, rolled.shape)
    return rolled
+
+
def stft_padding(
    input_signal,
    n_fft,
    window_length,
    hop_length,
    center_pad=True,
    end_pad=True,
    pad_mode="constant",
):
    """Pad a waveform for the STFT: optional center padding plus end padding
    so the signal splits into an integer number of windowed segments.

    :param input_signal: tensor of shape (..., n_samples), at most 3-dim
    :param n_fft: FFT size used to derive the center-padding amount
    :param window_length: STFT window size
    :param hop_length: STFT hop size
    :param center_pad: pad n_fft // 2 samples on both sides when True
    :param end_pad: pad the tail so the last window is complete when True
    :param pad_mode: padding mode forwarded to torch.nn.functional.pad
    :return: padded tensor with the original number of leading dims
    """
    pad_start = pad_end = 0
    signal_length = input_signal.shape[-1]
    if center_pad:
        # Center padding is done here rather than inside torch.stft because
        # torch.stft cannot combine it with the end padding below.
        pad_start = n_fft // 2
        pad_end = pad_start
        signal_length += pad_start + pad_end
    if end_pad:
        # Same strategy as scipy.signal.stft: extend so that
        # signal_length == window_length + (nseg - 1) * hop_length for integer nseg.
        pad_end += (-(signal_length - window_length) % hop_length) % window_length

    # F.pad needs a 3-dim view; restore the original rank afterwards.
    n_dim = input_signal.dim()
    view_shape = [1] * (3 - n_dim) + list(input_signal.size())
    padded = F.pad(input_signal.view(view_shape), (pad_start, pad_end), pad_mode)
    return padded.view(padded.shape[-n_dim:])
+
+
+def _normalize_options(window, method_str):
+ normalize_flag = False
+ if method_str == "window":
+ window = window / torch.sum(window)
+ elif method_str == "default":
+ normalize_flag = True
+ return window, normalize_flag
+
+
def do_stft(
    signal,
    window_length,
    hop_length,
    fft_size=None,
    normalize="default",
    window_type="sqrt_hann",
):
    """
    Wrap torch.stft, returning a transposed spectrogram for pytorch input

    :param signal: tensor of shape (n_channels, n_samples) or (n_samples)
    :param window_length: int size of stft window
    :param hop_length: int stride of stft window
    :param fft_size: int geq to window_length, if None set to window_length
    :param normalize: string for determining how to normalize stft output.
        "window": divide the window by its sum. This will give amplitudes of components that match
        time domain components, but small values could cause numerical issues
        "default": default pytorch stft normalization divides by sqrt of window length
        None: no normalization of stft outputs
    :param window_type: string indicating the window type ("sqrt_hann", "hann", "blackman", "blackmanharris", "hamming")
    :return: complex tensor of shape (n_frames, n_frequencies, n_channels) or (n_frames, n_frequencies)
    """
    fft_size = window_length if fft_size is None else fft_size
    padded = stft_padding(signal, fft_size, window_length, hop_length)
    window = get_window(window=window_type, window_length=window_length, device=padded.device)
    window, normalized = _normalize_options(window, normalize)
    spec = torch.stft(
        padded,
        n_fft=fft_size,
        hop_length=hop_length,
        win_length=window_length,
        window=window,
        normalized=normalized,
        center=False,  # padding already handled by stft_padding
        return_complex=True,
    )
    # Time axis first to better match how torch LSTMs consume input.
    return spec.transpose(0, -1)
+
+
def do_istft(
    stft,
    window_length=None,
    hop_length=None,
    fft_size=None,
    normalize="default",
    window_type="sqrt_hann",
):
    """
    Wrap torch.istft and return the time-domain signal

    :param stft: complex tensor of shape (n_frames, n_frequencies, n_channels) or (n_frames, n_frequencies)
    :param window_length: int size of stft window
    :param hop_length: int stride of stft window
    :param fft_size: int geq to window_length, if None set to window_length
    :param normalize: string for determining how to normalize stft output.
        "window": divide the window by its sum. This will give amplitudes of components that match
        time domain components, but small values could cause numerical issues
        "default": default pytorch stft normalization divides by sqrt of window length
        None: no normalization of stft outputs
    :param window_type: string indicating the window type ("sqrt_hann", "hann", "blackman", "blackmanharris", "hamming")
    :return: tensor of shape (n_samples, n_channels) or (n_samples)
    """
    window = get_window(window=window_type, window_length=window_length, device=stft.device)
    fft_size = window_length if fft_size is None else fft_size
    window, normalized = _normalize_options(window, normalize)
    # center=True is required so torch.istft's OLA constraint is satisfied.
    signal = torch.istft(
        stft.transpose(0, -1),
        fft_size,
        hop_length=hop_length,
        win_length=window_length,
        window=window,
        center=True,
        normalized=normalized,
    )
    # Transpose back so channels/sources sit in the last dimension, matching do_stft.
    return signal.T
+
+
+def get_window(window="sqrt_hann", window_length=1024, device=None):
+ if window == "sqrt_hann":
+ return sqrt_hann(window_length, device=device)
+ elif window == "hann":
+ return torch.hann_window(window_length, periodic=True, device=device)
+ elif window in ["blackman", "blackmanharris", "hamming"]:
+ return torch.Tensor(signal.get_window(window, window_length), device=device)
+
+
def sqrt_hann(window_length, device=None):
    """Square-root Hann window (suitable for analysis/synthesis window pairs)."""
    hann = torch.hann_window(window_length, periodic=True, device=device)
    return hann.sqrt()
+
+
def get_num_fft_bins(fft_size):
    """Number of non-redundant frequency bins for a real-input FFT of this size."""
    return fft_size // 2 + 1
+
+
def get_padded_stft_frames(signal_lens, window_length, hop_length, fft_size):
    """
    Compute stft frame lengths taking into account the padding used by our stft code

    :param signal_lens: torch tensor of signal waveform lengths
    :param window_length: scalar stft window length
    :param hop_length: scalar stft hop length
    :param fft_size: scalar stft fft size
    :return: torch tensor of stft frame lengths
    """
    # Center padding adds fft_size // 2 samples on each side (see stft_padding).
    padded_lens = signal_lens + 2 * (fft_size // 2)
    return torch.ceil((padded_lens - window_length) / hop_length + 1)
+
+
def stft_3dim(input, **kwargs):
    """Apply the STFT to each source of a 3-dim tensor independently.

    Parameters
    ----------
    input: torch.Tensor, (n_batch, n_samples, n_src)
        Input time-domain waveform

    Returns
    ----------
    output: torch.Tensor, (n_batch, n_frame, n_freq, n_src):
        Output STFT-domain spectrogram
    """
    assert input.ndim == 3, input.shape
    specs = []
    for src in range(input.shape[-1]):
        specs.append(
            do_stft(
                input[..., src],
                kwargs["window_length"],
                kwargs["hop_length"],
                kwargs["fft_size"],
                kwargs["normalize"],
            )
        )
    # do_stft yields (n_frame, n_freq, n_batch); stacking sources last and
    # moving the batch axis to the front gives (n_batch, n_frame, n_freq, n_src).
    return torch.stack(specs, dim=-1).movedim(2, 0)
+
+
def istft_4dim(input, **kwargs):
    """Apply the iSTFT to each batch item of a 4-dim tensor independently.

    Parameters
    ----------
    input: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
        Input STFT-domain spectrogram

    Returns
    ----------
    output: torch.Tensor, (n_batch, n_samples, n_src)
        Output time-domain waveform
    """
    assert input.ndim == 4, input.shape
    waveforms = []
    for b in range(input.shape[0]):
        waveforms.append(
            do_istft(
                input[b],
                kwargs["window_length"],
                kwargs["hop_length"],
                kwargs["fft_size"],
                kwargs["normalize"],
            )
        )
    return torch.stack(waveforms)
diff --git a/utils/collate.py b/utils/collate.py
new file mode 100644
index 0000000..6c0662e
--- /dev/null
+++ b/utils/collate.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import collections.abc as container_abcs
+
+import torch
+import torch.nn.utils.rnn as rnn
+
+
def collate_seq(batch):
    """Collate variable-length samples into padded batches, recursively.

    Tensors and ndarrays are zero-padded along their first axis and returned
    together with the original lengths; sequences and mappings are collated
    element-wise; anything else is returned uncollated as a list.
    """
    first = batch[0]

    # Name-based check avoids importing numpy just for isinstance.
    if type(first).__name__ == "ndarray":
        lengths = torch.tensor([len(sample) for sample in batch])
        tensors = [torch.tensor(sample, dtype=torch.float32) for sample in batch]
        return rnn.pad_sequence(tensors, batch_first=True), lengths

    if isinstance(first, torch.Tensor):
        lengths = torch.tensor([len(sample) for sample in batch])
        return rnn.pad_sequence(list(batch), batch_first=True), lengths

    if isinstance(first, container_abcs.Sequence):
        return [collate_seq(group) for group in zip(*batch)]

    if isinstance(first, container_abcs.Mapping):
        return {key: collate_seq([sample[key] for sample in batch]) for key in first}

    # for other stuff just return it and do not collate
    return list(batch)
+
+
def collate_seq_eras(batch):
    """Collate for ERAS training: pad along time, then fold the trailing
    channel axis of each tensor into the batch axis.

    ndarrays are rejected; tensors are padded to (n_batch, T, ..., n_chan),
    the channel axis is moved next to the batch axis, and both are merged.
    Sequences and mappings are collated element-wise; anything else is
    returned uncollated as a list.
    """
    first = batch[0]

    if type(first).__name__ == "ndarray":
        raise RuntimeError("Input must be torch Tensor or Dict")

    if isinstance(first, torch.Tensor):
        lengths = torch.tensor([len(sample) for sample in batch])
        padded = rnn.pad_sequence(list(batch), batch_first=True)
        # (n_batch, T, ..., n_chan) -> (n_batch, n_chan, T, ...) -> (n_batch * n_chan, T, ...)
        padded = padded.movedim(-1, 1)
        return padded.reshape((-1,) + padded.shape[2:]), lengths

    if isinstance(first, container_abcs.Sequence):
        return [collate_seq_eras(group) for group in zip(*batch)]

    if isinstance(first, container_abcs.Mapping):
        return {key: collate_seq_eras([sample[key] for sample in batch]) for key in first}

    # for other stuff just return it and do not collate
    return list(batch)
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..b45540c
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import argparse
+
+import yaml
+
+
+def _val_to_argparse_kwargs(val):
+ if isinstance(val, str):
+ return {"type": strings_with_none, "default": val}
+ elif isinstance(val, bool):
+ return {"type": bool_string_to_bool, "default": val}
+ else:
+ return {"type": eval, "default": val}
+
+
def strings_with_none(arg_str):
    """argparse type converter mapping "null"/"none" (any case) to None."""
    return None if arg_str.lower() in ("null", "none") else str(arg_str)
+
+
def bool_string_to_bool(bool_str):
    """argparse type converter for "True"/"False" strings (case-insensitive).

    :raises argparse.ArgumentTypeError: for anything other than true/false
    """
    lowered = str(bool_str).lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise argparse.ArgumentTypeError('For boolean args, use "True" or "False" strings, not {}'.format(bool_str))
+
+
def yaml_to_parser(yaml_path):
    """Build an argparse parser whose options mirror a YAML config file.

    Each top-level YAML key becomes a ``--key`` option (underscores turned
    into dashes) whose default is the YAML value.

    :param yaml_path: path to the YAML file of default hyperparameters
    :return: argparse.ArgumentParser with one option per YAML key
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original passed a bare open() and relied on garbage collection).
    with open(yaml_path) as f:
        default_hyperparams = yaml.safe_load(f)

    parser = argparse.ArgumentParser()
    for key, value in default_hyperparams.items():
        argparse_kwargs = _val_to_argparse_kwargs(value)
        parser.add_argument("--{}".format(key.replace("_", "-")), **argparse_kwargs)
    return parser
diff --git a/utils/forward_convolutive_prediction.py b/utils/forward_convolutive_prediction.py
new file mode 100644
index 0000000..65ec99a
--- /dev/null
+++ b/utils/forward_convolutive_prediction.py
@@ -0,0 +1,116 @@
+# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import torch
+
+
def forward_convolutive_prediction(
    est: torch.Tensor,
    ref: torch.Tensor,
    past_taps: int,
    future_taps: int,
    eps: float = 1e-8,
):
    """Forward convolutive prediction (FCP) proposed in [1].

    Parameters
    ----------
    est: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
        Signals to which the FCP is applied.
    ref: torch.Tensor, (n_batch, n_frame, n_freq, n_chan)
        Reference signals of the FCP.
    past_taps: int
        The number of the past taps in the FCP.
    future_taps: int
        The number of the future taps in the FCP.
    eps: float
        Stabilizer for matrix inverse.

    Returns
    ----------
    output: torch.Tensor, (n_batch, n_frame, n_freq, n_chan, n_src)
        Signals after applying the FCP.

    References
    ----------
    [1]: Z.-Q Wang, G. Wichern, and J. Le Roux,
    "Convolutive Prediction for Monaural Speech Dereverberation and Noisy-Reverberant Speaker Separation,"
    IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 29, pp. 3476-3490, 2021.
    """

    n_chan = ref.shape[-1]
    # stack past and future frames first: (n_batch, n_frame, n_taps, n_freq, n_src)
    est_tilde = stack_past_and_future_taps(est, past_taps, future_taps)

    # compute weighting term lambda: (n_batch, n_frame, n_freq, 1), then add two
    # trailing singleton dims so it broadcasts against the covariance tensors
    weight = compute_fcp_weight(ref)
    weight = weight[..., None, None]

    # compute FCP filter
    # get auto-covariance matrix over tap pairs:
    # (n_batch, n_frame, n_freq, n_src, n_taps, n_taps) before the frame sum
    auto_cov = torch.einsum("...tafn, ...tbfn -> ...tfnab", est_tilde, est_tilde.conj())
    # weighted sum over the frame axis (dim -5)
    auto_cov = (auto_cov / weight).sum(dim=-5)
    # replicate per output channel: (n_batch, n_freq, n_chan, n_src, n_taps, n_taps)
    auto_cov = auto_cov.unsqueeze(-4).tile(1, 1, n_chan, 1, 1, 1)

    # get cross-covariance matrix between stacked estimates and references:
    # (n_batch, n_frame, n_freq, n_chan, n_src, n_taps) before the frame sum
    cross_cov = torch.einsum("...tafn, ...tfc -> ...tfcna", est_tilde, ref.conj())
    cross_cov = (cross_cov / weight).sum(dim=-5)

    # compute relative RIR: (batch, past+1+future, freq, n_chan, n_src)
    # NOTE(review): eps is added to every entry of auto_cov rather than to the
    # diagonal (eps * I); presumably intentional -- confirm against the FCP paper.
    rir = torch.linalg.solve(auto_cov + eps, cross_cov)

    # filter the estimate: (batch, frame, freq, n_chan, n_src)
    output = torch.einsum("...fcna, ...tafn -> ...tfcn", rir.conj(), est_tilde)

    return output
+
+
def stack_past_and_future_taps(
    input,
    past_tap,
    future_tap,
):
    """Stack the past and future frames of the input along a new tap axis.

    Parameters
    ----------
    input: torch.Tensor, (n_batch, n_frame, n_freq, n_src)
        Signals to which the FCP is applied.
    past_tap: int
        Number of past taps in the FCP.
    future_tap: int
        Number of future taps in the FCP.

    Returns
    ----------
    output: torch.Tensor, (n_batch, n_frame, past_tap+1+future_tap, n_freq, n_src)
        output[:, t, k] == input[:, t - past_tap + k], zero-padded where the
        index falls outside the signal.
    """
    n_frame = input.shape[-3]
    n_taps = past_tap + future_tap + 1
    # Zero-pad the frame axis so every tap index is valid.
    padded = torch.nn.functional.pad(input, (0, 0, 0, 0, past_tap, future_tap))
    # taps[k] is the input shifted by (k - past_tap) frames.
    taps = [padded[:, k : k + n_frame] for k in range(n_taps)]
    return torch.stack(taps, dim=2)
+
+
def compute_fcp_weight(input, eps=1e-4):
    """Calculate the weighting term (lambda) in the FCP [1].

    Parameters
    ----------
    input: torch.Tensor, (n_batch, n_frame, n_freq, n_chan)
        Signal to compute the FCP weight.
    eps: float
        Floor factor, relative to the global maximum power, that keeps the
        weight strictly positive.

    Returns
    ----------
    weight: torch.Tensor, (n_batch, n_frame, n_freq, 1)
        Channel-averaged power floored by eps times the maximum power.
    """
    # Average the magnitude-squared spectrogram over the channel axis.
    power = input.abs().pow(2).mean(dim=-1)
    floor = eps * power.max()
    return (power + floor).unsqueeze(-1)