From bf9b8328e758275fa5c6c444a3b7899f7c3f1611 Mon Sep 17 00:00:00 2001 From: Devesh Jha Date: Mon, 20 Nov 2023 10:11:50 -0500 Subject: [PATCH] Create release 1.0.0 --- .github/workflows/requirements-dev.txt | 5 + .github/workflows/static_checks.yaml | 76 +++ .gitignore | 168 +++++ .pre-commit-config.yaml | 63 ++ .reuse/dep5 | 5 + .vscode/README.md | 9 + .vscode/extensions.json | 7 + .vscode/settings.json | 32 + CONTRIBUTING.md | 10 + LICENSE.md | 661 ++++++++++++++++++++ QNTRPO/compute_dogleg_step.py | 254 ++++++++ QNTRPO/compute_steepest_descent_step.py | 94 +++ QNTRPO/conjugate_grad_solution.py | 33 + QNTRPO/conjugate_grad_solution_fullmat.py | 33 + QNTRPO/initialize_iterate.py | 19 + QNTRPO/main.py | 208 ++++++ QNTRPO/models.py | 51 ++ QNTRPO/quasinewton_approximation_hessian.py | 75 +++ QNTRPO/replay_memory.py | 24 + QNTRPO/running_stats.py | 86 +++ QNTRPO/trust_region_opt_step.py | 53 ++ QNTRPO/trust_region_opt_torch.py | 267 ++++++++ QNTRPO/trust_region_step.py | 381 +++++++++++ QNTRPO/utils_trpo.py | 85 +++ README.md | 90 +++ requirements.txt | 8 + 26 files changed, 2797 insertions(+) create mode 100644 .github/workflows/requirements-dev.txt create mode 100644 .github/workflows/static_checks.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 .reuse/dep5 create mode 100644 .vscode/README.md create mode 100644 .vscode/extensions.json create mode 100644 .vscode/settings.json create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE.md create mode 100644 QNTRPO/compute_dogleg_step.py create mode 100644 QNTRPO/compute_steepest_descent_step.py create mode 100644 QNTRPO/conjugate_grad_solution.py create mode 100644 QNTRPO/conjugate_grad_solution_fullmat.py create mode 100644 QNTRPO/initialize_iterate.py create mode 100644 QNTRPO/main.py create mode 100644 QNTRPO/models.py create mode 100644 QNTRPO/quasinewton_approximation_hessian.py create mode 100644 QNTRPO/replay_memory.py create mode 100644 QNTRPO/running_stats.py create mode 100644 QNTRPO/trust_region_opt_step.py create mode 100644 QNTRPO/trust_region_opt_torch.py create mode 100644 QNTRPO/trust_region_step.py create mode 100644 QNTRPO/utils_trpo.py create mode 100644 README.md create mode 100644 requirements.txt diff --git a/.github/workflows/requirements-dev.txt b/.github/workflows/requirements-dev.txt new file mode 100644 index 0000000..4f1aa6e --- /dev/null +++ b/.github/workflows/requirements-dev.txt @@ -0,0 +1,5 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +pre-commit diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml new file mode 100644 index 0000000..7375f52 --- /dev/null +++ b/.github/workflows/static_checks.yaml @@ -0,0 +1,76 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Static code checks + +on: # yamllint disable-line rule:truthy + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +env: + LICENSE: AGPL-3.0-or-later + FETCH_DEPTH: 1 + FULL_HISTORY: 0 + SKIP_WORD_PRESENCE_CHECK: 0 + +jobs: + static-code-check: + if: endsWith(github.event.repository.name, 'private') + + name: Run static code checks + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Setup history + if: github.ref == 'refs/heads/oss' + run: | + echo "FETCH_DEPTH=0" >> $GITHUB_ENV + echo "FULL_HISTORY=1" >> $GITHUB_ENV + + - name: Setup version + if: 
github.ref == 'refs/heads/melco' + run: | + echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV + + - name: Check out code + uses: actions/checkout@v3 + with: + fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history + + - name: Set up environment + run: git config user.email github-bot@merl.com + + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: 3 + cache: 'pip' + cache-dependency-path: '.github/workflows/requirements-dev.txt' + + - name: Install python packages + run: pip install -r .github/workflows/requirements-dev.txt + + - name: Ensure lint and pre-commit steps have been run + uses: pre-commit/action@v3.0.0 + + - name: Check files + uses: merl-oss-private/merl-file-check-action@v1 + with: + license: ${{ env.LICENSE }} + full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above + skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} + + - name: Check license compatibility + if: github.ref != 'refs/heads/melco' + uses: merl-oss-private/merl_license_compatibility_checker@v1 + with: + input-filename: requirements.txt + license: ${{ env.LICENSE }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d4cdd2c --- /dev/null +++ b/.gitignore @@ -0,0 +1,168 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL). +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +# Python .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Custom ignores +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..576dfcd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,63 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# Pre-commit configuration. See https://pre-commit.com + +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + + - repo: https://gitlab.com/bmares/check-json5 + rev: v1.0.0 + hooks: + - id: check-json5 + + - repo: https://github.com/homebysix/pre-commit-macadmin + rev: v1.12.3 + hooks: + - id: check-git-config-email + args: ['--domains', 'merl.com'] + + - repo: https://github.com/psf/black + rev: 22.12.0 + hooks: + - id: black + args: + - --line-length=120 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] + + # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python + # - repo: https://github.com/asottile/pyupgrade + # rev: v3.3.1 + # hooks: + # - id: pyupgrade + + # To stop flake8 error from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings, + # so use verbose: true to see them. 
+  - repo: https://github.com/pycqa/flake8
+    rev: 5.0.4
+    hooks:
+      - id: flake8
+        # Black compatibility, Eradicate options
+        args: ["--max-line-length=120", "--extend-ignore=E203", "--eradicate-whitelist-extend", "eradicate:\\s*no",
+               "--exit-zero"]
+        verbose: true
+        additional_dependencies: [
+          # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate
+          "flake8-eradicate"
+        ]
diff --git a/.reuse/dep5 b/.reuse/dep5
new file mode 100644
index 0000000..288f2c7
--- /dev/null
+++ b/.reuse/dep5
@@ -0,0 +1,5 @@
+Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
+
+Files: .vscode/*
+Copyright: 2023 Mitsubishi Electric Research Laboratories (MERL)
+License: AGPL-3.0-or-later
diff --git a/.vscode/README.md b/.vscode/README.md
new file mode 100644
index 0000000..4eb80d7
--- /dev/null
+++ b/.vscode/README.md
@@ -0,0 +1,9 @@
+# VS Code recommended extensions and settings
+
+These files provide recommended extensions and workspace settings for VS Code for python development. The recommended extensions are:
+
+* [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python): Official python extension from Microsoft
+* [Python Type Hint](https://marketplace.visualstudio.com/items?itemName=njqdev.vscode-python-typehint): Type hint completion for Python
+* [autoDocstring](https://marketplace.visualstudio.com/items?itemName=njpwerner.autodocstring): Generates python docstrings automatically
+
+If these extensions are not already globally installed, they will be recommended to you for installation when you open the project in VS Code.
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 0000000..2d42587
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,7 @@
+{
+    "recommendations": [
+        "ms-python.python",
+        "njqdev.vscode-python-typehint",
+        "njpwerner.autodocstring"
+    ]
+}
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..b5520a9
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,32 @@
+{
+    "editor.rulers": [
+        120
+    ],
+    "[python]": {
+        "editor.tabSize": 4
+    },
+    "[markdown]": {
+        "editor.wordWrap": "bounded",
+        "editor.wordWrapColumn": 120
+    },
+    "files.eol": "\n",
+    "files.insertFinalNewline": true,
+    "files.trimFinalNewlines": true,
+    "files.trimTrailingWhitespace": true,
+    "editor.formatOnSave": true,
+    "python.formatting.provider": "black",
+    "python.formatting.blackArgs": [
+        "--line-length=120"
+    ],
+    "python.linting.flake8Enabled": true,
+    "python.linting.enabled": true,
+    "python.linting.flake8Args": [
+        "--max-line-length=120",
+        "--extend-ignore=E203"
+    ],
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..08901d5
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,10 @@
+<!--
+Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
+
+SPDX-License-Identifier: AGPL-3.0-or-later
+-->
+
+# Contributing
+
+Sorry, but we do not currently accept contributions in the form of pull requests to this repository. However, you are
+welcome to post issues (bug reports, feature requests, questions, etc).
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..b54b4da
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,661 @@
+
+### GNU AFFERO GENERAL PUBLIC LICENSE
+
+Version 3, 19 November 2007
+
+Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+ +### Preamble + +The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains +free software for all its users. + +When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + +A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + +The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + +An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing +under this license. + +The precise terms and conditions for copying, distribution and +modification follow. + +### TERMS AND CONDITIONS + +#### 0. Definitions. + +"This License" refers to version 3 of the GNU Affero General Public +License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. 
+ +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. + +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +#### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +#### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. 
This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. +You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +#### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +#### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +#### 5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. +- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +#### 6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. 
+ +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +#### 7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. 
+ +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +#### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. 
+ +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +#### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +#### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +#### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. 
+ +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +#### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. 
If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +#### 13. Remote Network Interaction; Use with the GNU General Public License. + +Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your +version supports such interaction) an opportunity to receive the +Corresponding Source of your version by providing access to the +Corresponding Source from a network server at no charge, through some +standard or customary means of facilitating copying of software. This +Corresponding Source shall include the Corresponding Source for any +work covered by version 3 of the GNU General Public License that is +incorporated pursuant to the following paragraph. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + +#### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU Affero General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever +published by the Free Software Foundation. + +If the Program specifies that a proxy can decide which future versions +of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +#### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +#### 16. Limitation of Liability. + +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +#### 17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +END OF TERMS AND CONDITIONS + +### How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively state +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as + published by the Free Software Foundation, either version 3 of the + License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper +mail. + +If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for +the specific requirements. + +You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU AGPL, see . 
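The QNTRPO modules that follow implement a dogleg trust-region step measured in the Fisher metric of the policy (the d^T A d <= delta^2 constraint described in the docstrings below). For orientation, here is a minimal sketch of the classical dogleg step on an explicit quadratic model with a Euclidean trust region; the standalone NumPy setting and every name in it are illustrative only and are not part of this patch.

    # Orientation sketch (not part of the release): classical dogleg step for
    #   min_d  g^T d + 0.5 d^T B d   s.t.  ||d|| <= delta,  with B symmetric positive definite.
    import numpy as np

    def dogleg_step(g, B, delta):
        d_newton = -np.linalg.solve(B, g)  # full (quasi-)Newton step
        if np.linalg.norm(d_newton) <= delta:
            return d_newton  # Newton step is interior: take it
        d_cauchy = -(g @ g) / (g @ B @ g) * g  # unconstrained minimizer along -g
        if np.linalg.norm(d_cauchy) >= delta:
            return -delta / np.linalg.norm(g) * g  # steepest descent to the boundary
        # Walk from the Cauchy point toward the Newton point and stop at the
        # boundary: solve ||d_cauchy + tau*(d_newton - d_cauchy)|| = delta for tau in [0, 1].
        diff = d_newton - d_cauchy
        aa = diff @ diff
        bb = 2.0 * (d_cauchy @ diff)
        cc = d_cauchy @ d_cauchy - delta**2
        tau = (-bb + np.sqrt(bb**2 - 4.0 * aa * cc)) / (2.0 * aa)
        return d_cauchy + tau * diff

The code below plays out the same three cases (interior Newton step, boundary steepest-descent step, dogleg interpolation) in the Fisher metric, where the constraint matrix A is available only through Fisher-vector products, so every inner solve is done with conjugate gradients.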
diff --git a/QNTRPO/compute_dogleg_step.py b/QNTRPO/compute_dogleg_step.py
new file mode 100644
index 0000000..8168a12
--- /dev/null
+++ b/QNTRPO/compute_dogleg_step.py
@@ -0,0 +1,254 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+"""Python script to compute the trust region step (dogleg method)."""
+
+import logging
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+logging.basicConfig(level=logging.DEBUG)
+
+from conjugate_grad_solution import conjugate_gradient
+from conjugate_grad_solution_fullmat import conjugate_gradient_fullmat
+
+logger = logging.getLogger(__name__)
+
+
+class trpo_step(object):
+    def __init__(self, n_var, type_option, model, f, get_kl, damping):
+
+        self._n_var = n_var
+        self.type = type_option
+        self.model = model
+        self.f = f
+        self.get_kl = get_kl
+        self.damping = damping
+
+        # self.conjugate_gradient_solver=conjugate_gradients
+
+    def compute_step(self, iterate, hessian, delta):
+
+        if self.type == 1:
+            ## Implement Dogleg method
+
+            """
+            Get the Newton direction and check whether it lies within the trust region.
+            original QP: min g^T d + 0.5*d^T*H*d s.t. d^T A d <= delta^2
+            transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. dhat^T dhat <= delta^2
+            where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T}
+            Newton step on the transformed QP: dxNhat = -Hhat^{-1}*ghat = -L^T*H^{-1}*g = L^T*dxN,
+            where dxN is the original Newton step.
+            So checking dxNhat^T*dxNhat <= delta^2 is equivalent to checking dxN^T*A*dxN <= delta^2.
+            """
+            ###### CG
+
+            a = torch.cat([grad.view(-1) for grad in iterate.g]).data
+
+            # Hessian_numpy=hessian.hessian.numpy()
+            # cond_hessian=np.linalg.cond(Hessian_numpy)
+            # print ("cond_hessian", cond_hessian)
+
+            a1 = -1 * a
+            # print ("Size of a1",a.size())
+            dxN, flag_cg = conjugate_gradient_fullmat(hessian, a1, 100, 1e-6)  # 2*self._n_var
+
+            print("dxN_size", torch.norm(dxN), "flag_cg", flag_cg)
+            if flag_cg == 0:
+
+                b = torch.cat([d.view(-1) for d in dxN]).data
+
+                AdxN = self.Fvp(b)  # np.matmul(iterate.A,dxN) ###### FVP
+
+                dxN_size = torch.sqrt(torch.dot(dxN, AdxN))  ### change the variable A
+
+                if dxN_size <= delta:
+                    dx = dxN
+                    dx_size = dxN_size
+                    flag = "N"
+                    return dx, dx_size, flag
+
+            """
+            Get the steepest descent direction taking A into account.
+            original QP: min g^T d + 0.5*d^T*H*d s.t. d^T A d <= delta^2
+            transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. dhat^T dhat <= delta^2
+            where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T}
+            dxShat = -ghat
+            We want the step size alphahat that minimizes the transformed QP:
+            alphahat = (ghat^T*ghat)/(ghat^T*Hhat*ghat) = (g^T*A^{-1}*g)/(g^T*A^{-1}*H*A^{-1}*g)
+            The norm of the transformed step is alphahat*norm(dxShat) = alphahat*sqrt(g^T*A^{-1}*g).
+            If this step size is >= delta, the transformed step is -delta/(sqrt(g^T*A^{-1}*g))*ghat
+            and the original step is -delta/(sqrt(g^T*A^{-1}*g))*Ainvg.
+            """
+            ##### Do CG
+            # Ainvg, flag_Ag =self.conjugate_gradient(iterate.A,iterate.g,1e-6,2*self._n_var,1) #self.conjugate_gradient_solver(A,g,2*self._n_var)
+
+            # print("size of gradient", a.size())
+            # a1=torch.cat([grad.view(-1) for grad in iterate.g]).data
+            Ainvg, flag_Ag = conjugate_gradient(self.Fvp, a, 100, 1e-6)
+
+            # flag_Ag=flag_Ag.numpy()
+            # print ("vector product",np.vdot(Ainvg,iterate.g),"flag_Ag", flag_Ag)
+            alpha_hat = 0
+            ghat_nrm = 0
+
+            if flag_Ag == 0:
+                # AgBAg = np.vdot(Ainvg,hessian.hessvec(Ainvg))
+                AgBAg = torch.dot(Ainvg, hessian.hessvec(Ainvg))
+                alpha_hat = torch.dot(Ainvg, a) / AgBAg
+
+                ###############
+                # x1=torch.dot(Ainvg,-a)*alpha_hat
+                # x2=AgBAg*0.5*alpha_hat**2
+                # x=-x1+x2
+                # print("Decrease in function", x)
+                #####################################
+
+                ghat_nrm = torch.sqrt(torch.dot(Ainvg, a))
+
+            dxShat_size = alpha_hat * ghat_nrm
+
+            print("dxShat_size", dxShat_size, "delta", delta)
+
+            if flag_cg > 0 or dxShat_size >= delta:
+
+                dx = -delta / ghat_nrm * Ainvg
+                dx_size = delta
+                flag = "S"
+                return dx, dx_size, flag
+
+            if flag_cg > 0 and dxShat_size < delta:
+                dx = -alpha_hat * Ainvg
+                dx_size = dxShat_size
+                flag = "S0"
+                return dx, dx_size, flag
+
+            ## if we failed to compute the Newton step or the steepest descent direction, resort to this
+
+            if flag_cg > 0 or flag_Ag > 0:
+                dxS = -torch.dot(a, a) / torch.dot(a, hessian.hessvec(a)) * a
+
+                b_flat = torch.cat([d.view(-1) for d in dxS]).data
+                # b_flat=b.view(-1)
+
+                AinvdxS = self.Fvp(b_flat)
+
+                ###----------------------------------
+                dxS_size = torch.sqrt(torch.dot(dxS, AinvdxS))  ####### FVP np.matmul(iterate.A,dxS)
+                dx = delta / dxS_size * dxS
+                dx_size = delta
+                flag = "s"
+                return dx, dx_size, flag
+
+            """
+            Get the dogleg step taking A into account.
+            original QP: min g^T d + 0.5*d^T*H*d s.t. d^T A d <= delta^2
+            transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. dhat^T dhat <= delta^2
+            where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T}
+            dxShat = -alphahat*ghat
+            dxNhat = L^T*dxN
+            Find alpha s.t. ||dxShat + alpha*(dxNhat - dxShat)||^2 = delta^2,
+            equiv. to solving ||L^T(L^{-T}*dxShat + alpha*(dxN - L^{-T}*dxShat))||^2 = delta^2,
+            equiv. to solving ||-alphahat*Ainvg + alpha*(dxN + alphahat*Ainvg)||^2_A = delta^2,
+            and form the quadratic equation in alpha.
+            """
+
+            # Ainvg=np.reshape(Ainvg,[len(Ainvg),1])
+
+            dxNAinvg = dxN + alpha_hat * Ainvg
+
+            # print ("dxNAinvg size", dxN.shape)
+            # Fvp=iterate.A
+            # a=torch.from_numpy(dxNAinvg)
+            # a=a.view(-1)
+            # print (a.size())
+
+            atimesdxNAing = self.Fvp(dxNAinvg)
+            # atimesdxNAing=atimesdxNAing
+
+            a_quad = torch.dot(dxNAinvg, atimesdxNAing)  ####### FVP np.matmul(iterate.A,dxNAinvg)
+            b_quad = -2 * alpha_hat * torch.dot(a, dxNAinvg)
+            c_quad = alpha_hat**2 * torch.dot(Ainvg, a) - delta**2
+
+            print("a_quad", "b_quad", "c_quad", a_quad, b_quad, c_quad)
+
+            ## Newton step and steepest descent are parallel
+            if a_quad <= 1e-6:
+
+                dx = -delta / ghat_nrm * Ainvg
+                dx_size = delta
+
+            alpha = np.roots([a_quad, b_quad, c_quad])
+
+            alpha_opt = np.max(alpha)
+
+            print("alpha", alpha_opt)
+
+            dx = -alpha_hat * Ainvg + alpha_opt * dxNAinvg
+
+            # dx_torch=torch.from_numpy(dx)
+            # dx_torch=dx_torch.view(-1)
+
+            # Fvp=iterate.A
+            Atimesdx = self.Fvp(dx)
+
+            # Atimesdxtorch=Atimesdxtorch.numpy()
+            dx_size = torch.sqrt(torch.dot(dx, Atimesdx))  ######## FVP np.matmul(iterate.A,dx)
+
+            if alpha_opt < 0:
+                print("alpha_opt", alpha_opt)
+
+                logging.debug("Error in computing the dogleg step")
+
+                dx = []
+
+                dx_size = 0
+
+            if abs(dx_size - delta) >= 1e-2:
+
+                print("distance", abs(dx_size - delta))
+                logging.debug("Error in computing Dogleg Step")
+
+            flag = "D"
+            return dx, dx_size, flag
+
+    def Fvp(self, v):
+
+        """
+        Compute the Fisher-vector product F*v (plus damping), where F is the
+        Fisher information matrix of the policy, i.e. the Hessian of the mean KL.
+        """
+        model = self.model
+        get_kl = self.get_kl
+        damping = self.damping
+
+        kl = get_kl()
+        kl = kl.mean()
+
+        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
+        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
+
+        kl_v = (flat_grad_kl * Variable(v)).sum()
+
+        grads = torch.autograd.grad(kl_v, model.parameters())
+
+        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
+
+        return flat_grad_grad_kl + v * damping
diff --git a/QNTRPO/compute_steepest_descent_step.py b/QNTRPO/compute_steepest_descent_step.py
new file mode 100644
index 0000000..26785da
--- /dev/null
+++ b/QNTRPO/compute_steepest_descent_step.py
@@ -0,0 +1,94 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+"""Python script to compute the trust region step (scaled steepest descent)."""
+
+import logging
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+logging.basicConfig(level=logging.DEBUG)
+
+from conjugate_grad_solution import conjugate_gradient
+
+logger = logging.getLogger(__name__)
+
+
+class trpo_step(object):
+    def __init__(self, n_var, type_option, model, f, get_kl, damping):
+
+        self._n_var = n_var
+        self.type = type_option
+        self.model = model
+        self.f = f
+        self.get_kl = get_kl
+        self.damping = damping
+
+        # self.conjugate_gradient_solver=conjugate_gradients
+
+    def compute_step(self, iterate, hessian, delta):
+
+        if self.type == 0:
+
+            ## Compute scaled steepest descent #### DO CG
+            # [Ainvg,flag_Ag] = self.conjugate_gradient(iterate.A,iterate.g,1e-6,2*self._n_var,1)
+            a = torch.cat([grad.view(-1) for grad in iterate.g]).data
+
+            Ainvg, flag_Ag = conjugate_gradient(self.Fvp, -a, 100, 1e-6)  # 2*self._n_var
+            # Ainvg=Ainvg.numpy()
+            # flag_Ag=flag_Ag.numpy()
+
+            print("flag_Ag", flag_Ag)
+
+            if flag_Ag == 0:
+                print("In the scaling loop")
+
+                vecpdt = torch.dot(Ainvg, -a)
+
+                scaledx = delta / torch.sqrt(vecpdt)
+
+                print("Scale dx", scaledx)
+                dx = scaledx * Ainvg  # scaledx*Ainvg.numpy()
+            else:
+                ##### Do FVP
+                Fvp = iterate.A
+                b = torch.from_numpy(iterate.g)
+                b = b.view(-1)
+                Ag = Fvp(b)
+                Ag = Ag.numpy()
+                # Ag = np.matmul(iterate.A,iterate.g)
+                scaledx = delta / np.sqrt(np.vdot(iterate.g, Ag))
+                dx = -scaledx * iterate.g
+
+            dx_size = delta
+            flag = "S"
+
+            return dx, dx_size, flag
+
+    def Fvp(self, v):
+
+        """
+        Compute the Fisher-vector product F*v (plus damping), where F is the
+        Fisher information matrix of the policy, i.e. the Hessian of the mean KL.
+        """
+        model = self.model
+        get_kl = self.get_kl
+        damping = self.damping
+
+        kl = get_kl()
+        kl = kl.mean()
+
+        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
+        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
+
+        kl_v = (flat_grad_kl * Variable(v)).sum()
+
+        grads = torch.autograd.grad(kl_v, model.parameters())
+
+        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
+
+        return flat_grad_grad_kl + v * damping
diff --git a/QNTRPO/conjugate_grad_solution.py b/QNTRPO/conjugate_grad_solution.py
new file mode 100644
index 0000000..b50a168
--- /dev/null
+++ b/QNTRPO/conjugate_grad_solution.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import numpy as np
+import torch
+
+
+def conjugate_gradient(Avp, b, nsteps, residual_tol=1e-6):
+    """
+    Returns an approximation of F^(-1)b, where F is the FIM (the Hessian of
+    the KL divergence), applied through the matrix-vector closure Avp.
+    """
+    p = b.clone().data
+    r = b.clone().data
+    x = np.zeros_like(b.data.cpu().numpy())
+
+    x = torch.from_numpy(x)
+
+    flag_cg = 1
+    rdotr = torch.dot(r, r)  # r.double().dot(r.double())
+    for _ in range(nsteps):
+        z = Avp(p)  # self.hessian_vector_product(Variable(p)).squeeze(0)
+        v = rdotr / torch.dot(p, z)  # p.double().dot(z.double())
+        x += v * p  # (p.cpu().numpy())
+        r -= v * z
+        newrdotr = torch.dot(r, r)  # r.double().dot(r.double())
+        mu = newrdotr / rdotr
+        p = r + mu * p
+        rdotr = newrdotr
+        if rdotr < residual_tol:
+            flag_cg = 0
+            break
+    return x, flag_cg
diff --git a/QNTRPO/conjugate_grad_solution_fullmat.py b/QNTRPO/conjugate_grad_solution_fullmat.py
new file mode 100644
index 0000000..8c800d5
--- /dev/null
+++ b/QNTRPO/conjugate_grad_solution_fullmat.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import numpy as np
+import torch
+
+
+def conjugate_gradient_fullmat(Avp, b, nsteps, residual_tol=1e-6):
+    """
+    Returns an approximation of H^(-1)b, where H is the quasi-Newton Hessian
+    approximation, applied through Avp.hessvec.
+    """
+    p = b.clone().data
+    r = b.clone().data
+    x = np.zeros_like(b.data.cpu().numpy())
+
+    x = torch.from_numpy(x)
+
+    flag_cg = 1
+    rdotr = torch.dot(r, r)  # r.double().dot(r.double())
+    for _ in range(nsteps):
+        z = Avp.hessvec(p)  # self.hessian_vector_product(Variable(p)).squeeze(0)
+        v = rdotr / torch.dot(p, z)  # p.double().dot(z.double())
+        x += v * p  # (p.cpu().numpy())
+        r -= v * z
+        newrdotr = torch.dot(r, r)  # r.double().dot(r.double())
+        mu = newrdotr / rdotr
+        p = r + mu * p
+        rdotr = newrdotr
+        if rdotr < residual_tol:
+            flag_cg = 0
+            break
+    return x, flag_cg
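Both conjugate-gradient routines above take a matrix-vector closure (or an object exposing hessvec) and return the approximate solution together with a convergence flag; flag_cg == 0 means the residual tolerance was reached within nsteps. A small self-contained usage sketch (the SPD matrix and right-hand side are made up for illustration):

    import torch
    from conjugate_grad_solution import conjugate_gradient

    # Solve F x = b for a tiny symmetric positive definite F, passing
    # matrix-vector multiplication by F as the Avp closure.
    F = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
    b = torch.tensor([1.0, 2.0], dtype=torch.float64)

    x, flag_cg = conjugate_gradient(lambda v: F @ v, b, nsteps=10, residual_tol=1e-10)
    print(x, flag_cg)  # x is approximately F^{-1} b; flag_cg == 0 on convergence

In the trust-region code above the same routine is called with self.Fvp as the closure, so the Fisher matrix is never formed explicitly; only Fisher-vector products are needed.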
+import numpy as np
+import torch
+
+
+class initialize_solution_iterate(object):
+    def __init__(self, x, f, g, A):
+
+        # self._function_class=function_class
+
+        # Container for the current iterate: parameters x, objective f, gradient g,
+        # and the KL-Hessian (Fisher-vector product) operator A.
+        self.x = x  # np.reshape(x,[len(x),1])
+        self.f = f  # function_class.fun(self.x)
+        self.g = g  # np.reshape(g,[len(g),1])#function_class.gradf(self.x)
+        self.A = A  # function_class.trust_region_hess(self.x)
+
+        self.error = torch.norm(self.g)
diff --git a/QNTRPO/main.py b/QNTRPO/main.py
new file mode 100644
index 0000000..e077548
--- /dev/null
+++ b/QNTRPO/main.py
@@ -0,0 +1,208 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import argparse
+from itertools import count
+
+import gym
+import scipy.optimize
+import torch
+from models import *
+from replay_memory import Memory
+from running_stats import Zfilter
+from torch.autograd import Variable
+from trust_region_opt_step import update_params_trust_region_step
+from utils_trpo import *
+
+torch.utils.backcompat.broadcast_warning.enabled = True
+torch.utils.backcompat.keepdim_warning.enabled = True
+
+torch.set_default_tensor_type("torch.DoubleTensor")
+
+parser = argparse.ArgumentParser(description="PyTorch actor-critic example")
+parser.add_argument("--discount_factor", type=float, default=0.99, metavar="G", help="discount factor (default: 0.99)")
+parser.add_argument("--env-name", default="Pendulum-v0", metavar="G", help="name of the environment to run")
+parser.add_argument("--tau", type=float, default=0.97, metavar="G", help="GAE parameter lambda (default: 0.97)")
+parser.add_argument(
+    "--l2-reg", type=float, default=1e-3, metavar="G", help="l2 regularization for the value regression (default: 1e-3)"
+)
+parser.add_argument("--max-kl", type=float, default=1e-2, metavar="G", help="max kl value (default: 1e-2)")
+parser.add_argument("--damping", type=float, default=1e-1, metavar="G", help="damping (default: 1e-1)")
+parser.add_argument("--seed", type=int, default=1234, metavar="N", help="random seed (default: 1234)")
+parser.add_argument("--batch-size", type=int, default=15000, metavar="N", help="batch size (default: 15000)")
+parser.add_argument("--render", action="store_true", help="render the environment")
+parser.add_argument(
+    "--log-interval", type=int, default=1, metavar="N", help="interval between training status logs (default: 1)"
+)
+# parser.add_argument('--optimization_type', type=int, default=1, metavar='N',
+#                     help='interval between training status logs (default: 10)')
+
+# parameters for second order optimization
+
+args = parser.parse_args()
+
+env = gym.make(args.env_name)
+
+num_inputs = env.observation_space.shape[0]
+num_actions = env.action_space.shape[0]
+
+env.seed(args.seed)
+torch.manual_seed(args.seed)
+
+policy_net = Policy(num_inputs, num_actions)
+value_net = Value(num_inputs)
+
+filename = "reward_results_rs_" + str(args.seed) + "_" + args.env_name + ".txt"
+
+file = open(filename, "w")
+
+
+def select_action(state):
+    state = torch.from_numpy(state).unsqueeze(0)
+    action_mean, _, action_std = policy_net(Variable(state))
+    action = torch.normal(action_mean, action_std)
+    return action
+
+
+def update_params(batch):
+
+    rewards = torch.Tensor(batch.reward)
+
+    masks = torch.Tensor(batch.mask)
+
+    actions = torch.Tensor(np.concatenate(batch.action, 0))
+    states = torch.Tensor(batch.state)
+    values = value_net(Variable(states))
+
+    returns = torch.Tensor(actions.size(0), 1)
+    deltas = torch.Tensor(actions.size(0), 1)
+    advantages = torch.Tensor(actions.size(0), 1)
+
+    prev_return = 0
+    prev_value = 0
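+    # The reversed loop below implements generalized advantage estimation (GAE):
+    # returns, TD residuals, and advantages are filled in from the last timestep backwards.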
prev_advantage = 0 + + for i in reversed(range(rewards.size(0))): + returns[i] = rewards[i] + args.discount_factor * prev_return * masks[i] + deltas[i] = rewards[i] + args.discount_factor * prev_value * masks[i] - values.data[i] + advantages[i] = deltas[i] + args.discount_factor * args.tau * prev_advantage * masks[i] + + prev_return = returns[i, 0] + prev_value = values.data[i, 0] + prev_advantage = advantages[i, 0] + + targets = Variable(returns) + + def get_value_loss(flat_params): + set_flat_params_to(value_net, torch.Tensor(flat_params)) + for param in value_net.parameters(): + if param.grad is not None: + param.grad.data.fill_(0) + + values_ = value_net(Variable(states)) + + value_loss = (values_ - targets).pow(2).mean() + + # weight decay + for param in value_net.parameters(): + value_loss += param.pow(2).sum() * args.l2_reg + value_loss.backward() + return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy()) + + flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b( + get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25 + ) + set_flat_params_to(value_net, torch.Tensor(flat_params)) + + advantages = (advantages - advantages.mean()) / advantages.std() + + action_means, action_log_stds, action_stds = policy_net(Variable(states)) + fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone() + + def get_loss(volatile=False): + if volatile: + with torch.no_grad(): + action_means, action_log_stds, action_stds = policy_net(Variable(states)) + else: + action_means, action_log_stds, action_stds = policy_net(Variable(states)) + + log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds) + action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob)) + return action_loss.mean() + + def get_kl(): + + mean1, log_std1, std1 = policy_net(Variable(states)) + + mean0 = Variable(mean1.data) + log_std0 = Variable(log_std1.data) + std0 = Variable(std1.data) + kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 + return kl.sum(1, keepdim=True) + + curr_params = get_flat_params_from(policy_net) + ### Update the TRPO steps using the first-order method. 
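+    ### (the call below is kept commented out for reference; the second-order
+    ### quasi-Newton trust-region update at the end of this block is the one that runs)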
+ ##------------------------------------------------------------------- + # trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) + ## ------------------------------------------------------------------ + ### Update the TRPO step using the second-order method + ##---------------------------------------------------------------------------- + # sec_order_trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) + update_params_trust_region_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) + ###------------------------------------------------------------------------- + + new_params = get_flat_params_from(policy_net) + + print("L2 norm of change in model parameters......", np.linalg.norm(curr_params.numpy() - new_params.numpy())) + + +running_stats = Zfilter((num_inputs,), clip=5) +running_reward = Zfilter((1,), demean=False, clip=10) + + +for i_episode in count(1): + memory = Memory() + + num_steps = 0 + reward_batch = 0 + num_episodes = 0 + while num_steps < args.batch_size: + state = env.reset() + state = running_stats(state) + + reward_sum = 0 + for t in range(2000): + action = select_action(state) + + action = action.data[0].numpy() + next_state, reward, done, _ = env.step(action) + reward_sum += reward + + next_state = running_stats(next_state) + + mask = 1 + if done: + mask = 0 + + memory.push(state, np.array([action]), mask, next_state, reward) + + # if args.render: + # if i_episode>10: + # env.render() + if done: + break + + state = next_state + num_steps += t - 1 + num_episodes += 1 + reward_batch += reward_sum + + reward_batch /= num_episodes + # print (i_episode,reward_batch) + file.write("{},{}\n".format(int(i_episode), float(reward_batch))) + batch = memory.sample() + update_params(batch) + + if i_episode % args.log_interval == 0: + print("Episode {}\tLast reward: {}\tAverage reward {:.2f}".format(i_episode, reward_sum, reward_batch)) diff --git a/QNTRPO/models.py b/QNTRPO/models.py new file mode 100644 index 0000000..0c24bf9 --- /dev/null +++ b/QNTRPO/models.py @@ -0,0 +1,51 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import torch +import torch.autograd as autograd +import torch.nn as nn + + +class Policy(nn.Module): + def __init__(self, num_inputs, num_outputs): + super(Policy, self).__init__() + self.affine1 = nn.Linear(num_inputs, 64) + self.affine2 = nn.Linear(64, 64) + + self.action_mean = nn.Linear(64, num_outputs) + self.action_mean.weight.data.mul_(0.1) + self.action_mean.bias.data.mul_(0.0) + + self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs)) + + self.saved_actions = [] + self.rewards = [] + self.final_value = 0 + + def forward(self, x): + x = torch.tanh(self.affine1(x)) + x = torch.tanh(self.affine2(x)) + + action_mean = self.action_mean(x) + action_log_std = self.action_log_std.expand_as(action_mean) + action_std = torch.exp(action_log_std) + + return action_mean, action_log_std, action_std + + +class Value(nn.Module): + def __init__(self, num_inputs): + super(Value, self).__init__() + self.affine1 = nn.Linear(num_inputs, 64) + self.affine2 = nn.Linear(64, 64) + self.value_head = nn.Linear(64, 1) + self.value_head.weight.data.mul_(0.1) + self.value_head.bias.data.mul_(0.0) + + def forward(self, x): + x = torch.tanh(self.affine1(x)) + x = torch.tanh(self.affine2(x)) + + state_values = self.value_head(x) + return state_values diff --git a/QNTRPO/quasinewton_approximation_hessian.py b/QNTRPO/quasinewton_approximation_hessian.py new file 
mode 100644 index 0000000..040923b --- /dev/null +++ b/QNTRPO/quasinewton_approximation_hessian.py @@ -0,0 +1,75 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import numpy as np +import torch + + +class quasinewton_approximation_torch(object): + def __init__(self, inithessian, approx_type, memory=30): + + self.inithessian = inithessian # initial hessian matrix for BFGS, scalar multiplying identity for L-BFGS + self.type = approx_type # 0: BFGS, 1: L-BFGS + self.memory = memory # indicates number of vectors that are stored for L-BFGS + + if self.type == 0: # BFGS + self.hessian = inithessian + if self.type == 1: # L-BFGS + self.delta = inithessian + self.m = 0 # number of vectors currently in the limited memory approximation + self.S = 0 # this is the collection of s vectors + self.Y = 0 # this is the collection of y vectors + self.Minv = 0 # this is the matrix in the middle of the limited memory approximation + self.L = 0 # this is the lower triangular matrix + self.d = 0 # this is the diagonal of the SE submatrix in M + self.STS = 0 # this is the diagonal of the SE submatrix in M + + def update(self, s, y): + + ys = torch.dot(y, s) + if ys <= 1e-3: + return + rho = 1.0 / torch.dot(y, s) + + print("rho", rho) + + if self.type == 0: # BFGS + Hess_s = torch.matmul(self.hessian, s) + sT_Hess_s = torch.dot(s, Hess_s) + self.hessian = self.hessian + rho * torch.ger(y, y) - 1 / (sT_Hess_s) * torch.ger(Hess_s, Hess_s) + + def hessvec(self, x): + + if self.type == 0: # BFGS + Hess_x = torch.matmul(self.hessian, x) + return Hess_x + + if self.type == 1: # L-BFGS + if self.m == 0: + Hess_x = self.delta * x + print("Hess_x", Hess_x.size()) + return Hess_x + + # TODO: pythonify this + x1 = torch.cat( + (self.delta * torch.matmul(torch.t(self.S), x), torch.matmul(torch.t(self.Y), x)), 0 + ) # np.vstack((self.delta*np.matmul(self.S.T,x),np.matmul(self.Y.T,x))) + x2 = torch.matmul(self.Minv, x1) + x3 = self.delta * torch.matmul(self.S, x2[0 : self.m]) + torch.matmul(self.Y, x2[self.m :]) + Hess_x = self.delta * x - x3 + + print("Hess_x", Hess_x.size()) + return Hess_x + + def reset(self): + + if self.type == 1: # L-BFGS + self.delta = self.delta + self.m = 0 # number of vectors currently in the limited memory approximation + self.S = 0 # this is the collection of s vectors + self.Y = 0 # this is the collection of y vectors + self.Minv = 0 # this is the matrix in the middle of the limited memory approximation + self.L = 0 # this is the matrix in the middle of the limited memory approximation + self.d = 0 # this is the diagonal of the SE submatrix in M + self.STS = 0 # this is the diagonal of the SE submatrix in M diff --git a/QNTRPO/replay_memory.py b/QNTRPO/replay_memory.py new file mode 100644 index 0000000..86e0e81 --- /dev/null +++ b/QNTRPO/replay_memory.py @@ -0,0 +1,24 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import random +from collections import namedtuple + +Transition = namedtuple("Transition", ("state", "action", "mask", "next_state", "reward")) + + +class Memory(object): + def __init__(self): + self.memory = [] + + def push(self, *args): + + self.memory.append(Transition(*args)) + + def sample(self): + + return Transition(*zip(*self.memory)) + + def __len__(self): + return len(self.memory) diff --git a/QNTRPO/running_stats.py b/QNTRPO/running_stats.py new file mode 100644 index 0000000..2c75e9e --- /dev/null +++ 
b/QNTRPO/running_stats.py @@ -0,0 +1,86 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +from collections import deque + +import numpy as np + + +class RunningStats(object): + def __init__(self, shape): + self._n = 0 + self._M = np.zeros(shape) + self._S = np.zeros(shape) + + def push(self, x): + x = np.asarray(x) + + assert x.shape == self._M.shape + + self._n += 1 + if self._n == 1: + self._M[...] = x + + else: + oldM = self._M.copy() + self._M[...] = oldM + (x - oldM) / self._n + self._S[...] = self._S + (x - oldM) * (x - self._M) + + @property + def n(self): + return self._n + + @property + def mean(self): + return self._M + + @property + def var(self): + return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) + + @property + def std(self): + return np.sqrt(self.var) + + @property + def shape(self): + return self._M.shape + + +class Zfilter: + + """ + y=(x-mean)/std + using running estimates of mean,std + + """ + + def __init__(self, shape, demean=True, destd=True, clip=10.0): + + self.demean = demean + self.destd = destd + self.clip = clip + + self.rs = RunningStats(shape) + + def __call__(self, x, update=True): + + if update: + self.rs.push(x) + + if self.demean: + + x = x - self.rs.mean + + if self.destd: + x = x / (self.rs.std + 1e-8) + + if self.clip: + x = np.clip(x, -self.clip, self.clip) + + return x + + def output_shape(self, input_space): + + return input_space.shape diff --git a/QNTRPO/trust_region_opt_step.py b/QNTRPO/trust_region_opt_step.py new file mode 100644 index 0000000..46a2a3e --- /dev/null +++ b/QNTRPO/trust_region_opt_step.py @@ -0,0 +1,53 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import sys + +import numpy as np +import torch +from torch.autograd import Variable + +# sys.path.append("./Second_Order_Method") +from trust_region_opt_torch import * +from utils_trpo import * + + +def update_params_trust_region_step(model, get_loss, get_kl, max_kl, damping): + + """ + + trust_region_step: second order trpo step + model: torch model + get_kl: method kl divergence + max_kl : maximum kl_divergence + damping: scalar parameter to make conjugate gradient method more stable -- refer to CG literature for more clarity. 
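+    get_loss: method returning the surrogate policy loss to be reduced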
+
+    """
+
+    loss = get_loss()
+    grads = torch.autograd.grad(loss, model.parameters())
+    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
+
+    # ---- Get the current model parameters
+
+    curr_params = get_flat_params_from(model)
+
+    Dogleg_method_object = TR_Optimizer(
+        model, get_loss, get_kl, damping, max_kl, 1e-3, 10, 1
+    )  ## Create the trust region optimization object
+
+    new_params = Dogleg_method_object.solve()
+
+    # new_params=torch.from_numpy(new_params)
+    # new_params=new_params.view(-1)
+    set_flat_params_to(model, new_params)
+
+    new_params = get_flat_params_from(model)
+
+    print("L2 norm of change in model parameters......", np.linalg.norm(curr_params.numpy() - new_params.numpy()))
+
+    # print ("New step with improvement in Optimization function,...", success)
+
+    return loss
diff --git a/QNTRPO/trust_region_opt_torch.py b/QNTRPO/trust_region_opt_torch.py
new file mode 100644
index 0000000..9eb1756
--- /dev/null
+++ b/QNTRPO/trust_region_opt_torch.py
@@ -0,0 +1,267 @@
+# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+"""Python script for trust region optimization using the Dogleg Method"""
+
+# from main import get_kl
+#
+import logging
+
+import numpy as np
+import torch
+
+# from compute_steepest_descent_step import *
+from compute_dogleg_step import *
+from initialize_iterate import initialize_solution_iterate
+
+# from lbfgs_approximation import lbfgs_approx
+from quasinewton_approximation_hessian import quasinewton_approximation_torch
+
+# from compute_trust_region_step import trpo_step
+from utils_trpo import *
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger(__name__)
+
+
+class TR_Optimizer(object):
+    def __init__(self, model, f, get_kl, damping, delta0, tol, maxiter, step_type):
+        """
+        model: torch model whose flattened parameters are the optimization variables
+        f: closure returning the surrogate objective; get_kl: closure returning the KL divergence
+        damping: damping added to the Fisher-vector products
+        delta0: initial trust region radius (the caller passes max_kl)
+        tol: convergence tolerance on the gradient norm
+        maxiter: max number of iterations
+        step_type: 0 scaled steepest descent, 1 dogleg, 2 Steihaug CG
+
+        """
+
+        self.model = model
+        self.f = f()
+
+        self.fval = f
+
+        curr_params = get_flat_params_from(model)
+
+        self._n_var = int(list(curr_params.size())[0])  # problem.get_nvar()
+
+        self._inithessian = torch.eye(self._n_var)  # self._problem.initialize_hessian()
+
+        ## Trust Region Hyperparameters
+        self._delta0 = delta0
+        self._tol = tol
+        self._maxiter = maxiter
+        self._step_type = step_type
+
+        self.damping = damping
+        self.get_kl = get_kl
+
+        ###set parameters for TR optimization
+
+        self._parameters = {
+            "tr_ratio_good": 0.75,  # if ared/pred >= ratio_good then there is possibility of TR increase
+            "tr_ratio_bad": 0.1,  # if ared/pred < ratio_bad then decrease TR
+            "tr_ratio_accept": 1e-4,  # if ared/pred > ratio_accept then the step is accepted
+            "tr_step_factor": 0.8,  # increase TR when step is 0.8*(current TR)
+            "tr_delta_small": 1e-5,  # threshold when the TR is assumed to have become unacceptably small
+            "tr_inc_delta": 2.0,  # multiplicative factor for TR when increasing
+            "tr_dec_delta": 0.3,  # multiplicative factor for TR when decreasing
+            "tr_maxdelta": 1 * 1e-1,  # max TR
+            "tr_lbfgs": 0,  # 0: BFGS, 1: L-BFGS; do not set this to 1 -- L-BFGS support was removed since it was not used in the paper.
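+            # tr_lm_kmax: number of curvature pairs kept by the (disabled) L-BFGS approximation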
+ "tr_lm_kmax": np.min((self._n_var, 30)), + } + + # if 'numpy' in str(type(self._inithessian)):#len(self._inithessian)>1: + # self._parameters['tr_lbfgs']=0 + # else: + # self._parameters['tr_lbfgs']=1 + + self._iter = 0 + self._loop = 1 + + if self._parameters["tr_lbfgs"] == 1: + self._inithessian = 1.0 + + ##initialize the problem class here + self._x0 = get_flat_params_from(model) + self.f0 = f(True).data + self.f0 = self.f0.item() + self.grad0 = torch.autograd.grad(self.f, self.model.parameters(), create_graph=True, retain_graph=True) + + flat_grad = torch.cat([grad.view(-1) for grad in self.grad0]).data + + self.grad0 = flat_grad + + # print (self.grad0) + self._trust_region_hessian0 = torch.eye(self._n_var) + + self.iterate = initialize_solution_iterate(self._x0, self.f0, self.grad0, self._trust_region_hessian0) + + ##initialize hessian approximation + # self.hessian=quasinewton_approximation_torch(self._inithessian,self._parameters['tr_lbfgs'],self._parameters['tr_lm_kmax']) + if self._step_type > 0: + self.hessian = quasinewton_approximation_torch( + self._inithessian, self._parameters["tr_lbfgs"], self._parameters["tr_lm_kmax"] + ) + else: + self.hessian = 0 + + ##initialize trust-region step + self.trpo_step = trpo_step(self._n_var, step_type, model, f, get_kl, damping) + + if self.iterate.error <= tol: + self._loop = 0 + + ##initialize trust-region radius + self._delta = self._delta0 + + def solve(self): + + ## compute_step + while self._loop == 1: + + dx, dx_size, flag_step = self.trpo_step.compute_step(self.iterate, self.hessian, self._delta) # ,Hessvec + + # x_curr=self.iterate.x + # x_curr=x_curr.view(-1) + # set_flat_params_to(self.model,x_curr) + # f_curr=self.fval(True).data + # f_curr=f_curr.numpy() + # print('fval at current point,',f_curr) + + x_new = self.iterate.x + dx + + # xnew=torch.from_numpy(x_new) + xnew = torch.cat([x.view(-1) for x in x_new]).data + # xnew=xnew.view(-1) + + set_flat_params_to(self.model, xnew) + f_new = self.fval(True).data + + print("fval at new point", f_new) + grads = torch.autograd.grad(self.f, self.model.parameters(), create_graph=False, retain_graph=True) + g_new = torch.cat([grad.view(-1) for grad in grads]).data + + act_dec = self.iterate.f - f_new + + if self._step_type == 0: + a = torch.cat([grad.view(-1) for grad in self.iterate.g]).data + pre_dec = -torch.dot(a, dx) + + # pre_dec=pre_dec.numpy() + print("Predicted decrease and actual decrease difference", pre_dec - act_dec) + else: + Hdx = self.hessian.hessvec(dx) + pre_dec = -np.vdot(self.iterate.g, dx) - 0.5 * np.vdot(dx, Hdx) + if pre_dec <= 0: + logging.debug( + "flag = %c gdx = %e quad = %e", flag_step, np.vdot(self.iterate.g, dx), 0.5 * np.vdot(dx, Hdx) + ) + + # ratio=act_dec/(1e-16+pre_dec) + ratio = act_dec / (pre_dec + 1e-16) + + ## Check progress of solution + accept_flag = 1 + if act_dec <= 0: + print(self._iter, " act_dec=", act_dec, " pre_dec=", pre_dec) + # print (act_) + if act_dec <= 0 or ratio <= self._parameters["tr_ratio_accept"]: + accept_flag = 0 + + delta_old = self._delta + delta_change = 0 + + if act_dec >= 0 and ratio >= self._parameters["tr_ratio_good"]: + + if dx_size >= self._parameters["tr_step_factor"] * self._delta: + + self._delta = min(self._parameters["tr_maxdelta"], self._delta * self._parameters["tr_inc_delta"]) + delta_change = 1 + + elif ratio >= self._parameters["tr_ratio_bad"] and ratio <= self._parameters["tr_ratio_good"]: + pass ## do nothing + + else: + self._delta = self._delta * self._parameters["tr_dec_delta"] + delta_change = -1 + + 
s = dx + + # print('size of s-----',s.shape,self.iterate.g.shape) + + g_earlier = g_new + + y = g_new - self.iterate.g + + # print('size of y---------',y.shape) + + if accept_flag == 1: + + self.iterate.x = x_new + self.iterate.f = f_new + self.iterate.g = g_new + # self.iterate.A=self._trust_region_hessian0#self._problem.trust_region_hess(x_new) #(self.iterate.x) + self.iterate.error = self.compute_error(g_new) + + ### update hessian + if self._step_type > 0: + self.hessian.update(s, y) + + if ( + self._parameters["tr_lbfgs"] == 1 + and act_dec <= 0 + and (delta_change == 0 or self._delta <= 10 * self._parameters["tr_delta_small"]) + ): + + ## indication that there is no progress. Get rid of the vectors and start over + self.hessian.reset() + # resetting the trust region size if this becomes too small + if self._delta <= 10 * self._parameters["tr_delta_small"]: + self._delta = self._delta0 + + self._iter += 1 + + ## print statistics + nrmHess = 0.0 + if self._step_type > 0 and self._parameters["tr_lbfgs"] == 0: + nrmHess = torch.norm(self.hessian.hessian) + if np.mod(self._iter, 10) == 1: + + logging.debug("Iteration Objective ||g|| Ratio ||dx|| Accept delta Change nrm(Hess)") + + logging.debug( + "%d %e %e %e %e%c %d %e %d %e", + self._iter, + self.iterate.f, + self.iterate.error, + ratio, + dx_size, + flag_step, + accept_flag, + delta_old, + delta_change, + nrmHess, + ) + # logging.debug (self._iter, self.iterate.f, self.iterate.error, ratio, dx_size, accept_flag, delta_old, delta_change, nrmHess) + + if self.iterate.error <= self._tol: + self._loop = 0 + + break + elif self._iter >= self._maxiter: + self._loop = 1 + break + elif self._delta <= self._parameters["tr_delta_small"] and accept_flag == 0: + self._loop = 2 + break + + return xnew + + def compute_error(self, g): + + a = torch.norm(g) + + return a diff --git a/QNTRPO/trust_region_step.py b/QNTRPO/trust_region_step.py new file mode 100644 index 0000000..fd7b994 --- /dev/null +++ b/QNTRPO/trust_region_step.py @@ -0,0 +1,381 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Python script to compute the trust region step """ + +# from conjugate_grad_solver import conjugate_gradients +# import scipy.sparse +import logging + +import numpy as np +import torch + +logging.basicConfig(level=logging.DEBUG) + +from conjugate_grad_solution import conjugate_gradient + +logger = logging.getLogger(__name__) + + +class trpo_step(object): + def __init__(self, n_var, type_option, model, get_loss, get_kl, damping): + + self._n_var = n_var + self.type = type_option + + self.model = model + self.get_kl = get_kl + self.get_loss = get_loss + self.damping = damping + + # self.conjugate_gradient_solver=conjugate_gradients + + def Fvp(self, v): + get_kl = self.get_kl + loss = self.get_loss() + damping = self.damping + model = self.model + + kl = get_kl() + kl = kl.mean() + + grads = torch.autograd.grad(kl, model.parameters(), create_graph=True) + flat_grad_kl = torch.cat([grad.view(-1) for grad in grads]) + + kl_v = (flat_grad_kl * Variable(v)).sum() + + grads = torch.autograd.grad(kl_v, model.parameters()) + flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data + + return flat_grad_grad_kl + v * damping + + def compute_step(self, iterate, hessian, delta): + + if self.type == 1: + ## Implement Dogleg method + + """ + + get the Newton direction, check to see if it is within the trust-region + original QP: min g^T d + 0.5*d^T*H*d s.t. 
d^T A d <= delta^2 + transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. dhat^T dhat <= delta^2 + where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T} + newton step on transformed: dxNhat = -Hhat grad_hat = -L^T*H^{-1}*g = L^T*dxN + where dxN is the original newton step + So checking if dxNhat^T*dxNhat <= delta^2 is equivalent to dxN^T*A*dxN <= delta^2 + + """ + ###### CG + dxN, flag_cg = self.conjugate_gradient(hessian, -iterate.g, 1e-6, 2 * self._n_var, 0) + + print("flag_cg", flag_cg) + if flag_cg == 0: + + a = torch.from_numpy(dxN) + a = a.view(-1) + Fvp = iterate.A + AdxN = Fvp(a) # np.matmul(iterate.A,dxN) ###### FVP + AdxN = AdxN.numpy() + dxN_size = np.sqrt(np.vdot(dxN, AdxN)) ### change the variable A + + if dxN_size <= delta: + dx = dxN + dx_size = dxN_size + flag = "N" + return dx, dx_size, flag + + """ + Get the steepest descent direction taking A into account + original QP: min g^T d + 0.5*d^T*H*d s.t. d^T A d <= delta^2 + transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. dhat^T dhat <= delta^2 + where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T} + dxShat = -ghat + We want to find the step size alphahat that minimizes transformed QP + alphahat = (ghat^T*ghat)/(ghat^T*Hhat*ghat) = (g^T*A^{-1}*g)/(g^T*A^{-1}*H*A^{-1}*g) + norm of the transformed step size is alphahat*norm(dxShat) = alphahat*sqrt(g^T*A^{-1}*g) + If this step size is >= delta then the transformed step is -delta/(sqrt(g^T*A^{-1}*g))*ghat + The original step -delta/(sqrt(g^T*A^{-1}*g))*Ainvg + + """ + ##### Do CG + # Ainvg, flag_Ag =self.conjugate_gradient(iterate.A,iterate.g,1e-6,2*self._n_var,1) #self.conjugate_gradient_solver(A,g,2*self._n_var) + + a = torch.from_numpy(iterate.g) + a = a.view(-1) + + # print("size of gradient", a.size()) + Ainvg, flag_Ag = conjugate_gradient(self.Fvp, a, 2 * self._n_var, 1e-6) + + Ainvg = Ainvg.numpy() + + print("Ainvg size", Ainvg.shape) + # flag_Ag=flag_Ag.numpy() + + print("vector product", np.vdot(Ainvg, iterate.g), "flag_Ag", flag_Ag) + alpha_hat = 0 + ghat_nrm = 0 + + if flag_Ag == 0: + AgBAg = np.vdot(Ainvg, hessian.hessvec(Ainvg)) + alpha_hat = np.vdot(Ainvg, iterate.g) / AgBAg + + ############### + x1 = np.vdot(Ainvg, iterate.g) * alpha_hat + + x2 = AgBAg * 0.5 * alpha_hat**2 + + x = -x1 + x2 + + print("Decrease in function", x) + ##################################### + + ghat_nrm = np.sqrt(np.vdot(Ainvg, iterate.g)) + + dxShat_size = alpha_hat * ghat_nrm + + if flag_cg > 0 or dxShat_size >= delta: + + dx = -delta / ghat_nrm * Ainvg + dx_size = delta + flag = "S" + return dx, dx_size, flag + + if flag_cg > 0 and dxShat_size < delta: + dx = -alpha_hat * Ainvg + dx_size = dxShat_size + flag = "S0" + return dx, dx_size, flag + + ## if failed to compute the Newton Step or the steepest descent direction, resort to this + + if flag_cg > 0 or flag_Ag > 0: + dxS = -np.vdot(iterate.g, iterate.g) / np.vdot(iterate.g, hessian.hessvec(iterate.g)) * iterate.g + Fvp = iterate.A + b = torch.from_numpy(dxS) + b = b.view(-1) + + AinvdxS = Fvp(b) + AinvdxS = AinvdxS.numpy() + + ###---------------------------------- + dxS_size = np.sqrt(np.vdot(dxS, AinvdxS)) #######FVPnp.matmul(iterate.A,dxS + dx = delta / dxS_size * dxS + dx_size = delta + flag = "s" + return dx, dx_size, flag + + """ + get the dogleg step taking A into account + original QP: min g^T d + 0.5*d^T*H*d s.t. d^T A d <= delta^2 + transformed QP: min ghat^T dhat + 0.5*dhat^T*Hhat*dhat s.t. 
dhat^T dhat <= delta^2
+            where A = L*L^T, dhat = L^T*d, ghat = L^{-1}*g, Hhat = L^{-1}*H*L^{-T}
+            dxShat = -alphahat*ghat
+            dxNhat = L^T*dxN
+            Find alpha s.t. ||dxShat + alpha*(dxNhat - dxShat)||^2 = delta^2
+            equiv. to solving ||L^T(L^{-T}*dxShat + alpha*(dxN - L^{-T}*dxShat))||^2 = delta^2
+            equiv. to solving ||-alphahat*Ainvg + alpha*(dxN + alphahat*Ainvg)||^2_A = delta^2
+            form the quadratic equation
+
+
+            """
+
+            Ainvg = np.reshape(Ainvg, [len(Ainvg), 1])
+
+            dxNAinvg = dxN + alpha_hat * Ainvg
+
+            print("dxNAinvg size", dxNAinvg.shape)
+            Fvp = iterate.A
+            a = torch.from_numpy(dxNAinvg)
+            a = a.view(-1)
+
+            print(a.size())
+
+            atimesdxNAing = Fvp(a)
+            atimesdxNAing = atimesdxNAing.numpy()
+
+            a_quad = np.vdot(dxNAinvg, atimesdxNAing)  ####### FVPnp.matmul(iterate.A,dxNAinvg)
+            b_quad = -2 * alpha_hat * np.vdot(iterate.g, dxNAinvg)
+            c_quad = alpha_hat**2 * np.vdot(Ainvg, iterate.g) - delta**2
+
+            ## Newton step and steepest descent are parallel: fall back to the scaled steepest descent step
+
+            if a_quad <= 1e-6:
+
+                dx = -delta / ghat_nrm * Ainvg
+                dx_size = delta
+
+            else:
+                alpha = np.roots([a_quad, b_quad, c_quad])
+
+                alpha_opt = np.max(alpha)
+
+                dx = -alpha_hat * Ainvg + alpha_opt * dxNAinvg
+
+                dx_torch = torch.from_numpy(dx)
+                dx_torch = dx_torch.view(-1)
+
+                Fvp = iterate.A
+                Atimesdxtorch = Fvp(dx_torch)
+
+                Atimesdxtorch = Atimesdxtorch.numpy()
+                dx_size = np.sqrt(np.vdot(dx, Atimesdxtorch))  ######## FVP np.matmul(iterate.A,dx)
+
+                if alpha_opt < 0:
+
+                    logging.debug("Error in computing the dogleg step")
+
+                    dx = []
+
+                    dx_size = 0
+
+            if abs(dx_size - delta) >= 1e-2:
+                logging.debug("Error in computing Dogleg Step")
+
+            flag = "D"
+            return dx, dx_size, flag
+
+        if self.type == 0:
+
+            ## Compute scaled steepest descent  #### DO CG
+            # [Ainvg,flag_Ag] = self.conjugate_gradient(iterate.A,iterate.g,1e-6,2*self._n_var,1)
+            a = torch.from_numpy(iterate.g)
+            a = a.view(-1)
+
+            Ainvg, flag_Ag = conjugate_gradient(self.Fvp, -a, 2 * self._n_var, 1e-6)
+            # Ainvg=Ainvg.numpy()
+            # flag_Ag=flag_Ag.numpy()
+
+            # No line search is needed here: the trust region radius delta already
+            # bounds the length of the scaled steepest descent step.
+
+            print("flag_Ag", flag_Ag)
+
+            if flag_Ag == 0:
+                print("In the scaling loop")
+
+                vecpdt = torch.dot(Ainvg, -a)
+                vecpdt = vecpdt.numpy()
+                scaledx = delta / np.sqrt(vecpdt)
+
+                print("Scale dx", scaledx)
+                dx = scaledx * Ainvg.numpy()
+            else:
+                ##### Do FVP
+                Fvp = iterate.A
+                b = torch.from_numpy(iterate.g)
+                b = b.view(-1)
+                Ag = Fvp(b)
+                Ag = Ag.numpy()
+                # Ag = np.matmul(iterate.A,iterate.g)
+                scaledx = delta / np.sqrt(np.vdot(iterate.g, Ag))
+                dx = -scaledx * iterate.g
+
+            dx_size = delta
+            flag = "S"
+
+            return dx, dx_size, flag
+
+        if self.type == 2:
+
+            ## Compute combined step
+            [dx, flag, flag_step] = self.conjugate_gradient_steihaug(hessian, iterate.g, 1e-6, 2 * self._n_var, delta)
+
+            dx_size = np.linalg.norm(dx)
+            flag = "C"
+
+            return dx, dx_size, flag_step
+
+    def conjugate_gradient(self, hessOrMat, b, residual_tol, nsteps, flag):
+
+        ## flag == 0: hessian, 1: matrix
+        x = np.zeros([len(b), 1])
+        nrmb = np.linalg.norm(b)
+        if flag == 0:
+            r = -b + hessOrMat.hessvec(x)
+        else:
+            r = -b + np.matmul(hessOrMat, x)
+        p = -np.copy(r)
+        rdotr = np.vdot(r, r)  # torch.dot(r,r)
+        # print("Initial residual = ",np.linalg.norm(r))
+        flag_cg = 1
+        if np.sqrt(rdotr) <= residual_tol or nrmb <= residual_tol or np.sqrt(rdotr) / nrmb <= residual_tol:
+            flag_cg = 0
+            return x, flag_cg
+
+        for i in range(nsteps):
+            if flag == 0:
+                _Avp = hessOrMat.hessvec(p)
+            else:
+                _Avp = np.matmul(hessOrMat, p)
+            p_Avp =
np.vdot(p, _Avp) + alpha = rdotr / (p_Avp) # torch.dot(p,_Avp) + x += alpha * p + r += alpha * _Avp + + new_rdotr = np.vdot(r, r) # torch.dot(r,r) + + beta = new_rdotr / rdotr + + p = -r + beta * p + + rdotr = new_rdotr + + if np.sqrt(rdotr) <= residual_tol or np.sqrt(rdotr) / nrmb <= residual_tol: + flag_cg = 0 + break + # print("CG: ",rdotr," its ",i) + + return x, flag_cg + + def conjugate_gradient_steihaug(self, hessOrMat, b, residual_tol, nsteps, delta): + + ## flag == 0: hessian, 1: matrix + z = np.zeros([len(b), 1]) + r = np.copy(b) + d = -np.copy(r) + nrmb = np.linalg.norm(b) + rdotr = np.vdot(r, r) # torch.dot(r,r) + # print("Initial residual = ",np.linalg.norm(r)) + flag_cg = 1 + if np.sqrt(rdotr) <= residual_tol or nrmb <= residual_tol or np.sqrt(rdotr) / nrmb <= residual_tol: + flag_cg = 0 + flag_step = "S" + return z, flag_cg, flag_step + + for i in range(nsteps): + _Avd = hessOrMat.hessvec(d) + d_Avd = np.vdot(d, _Avd) + alpha = rdotr / d_Avd # torch.dot(p,_Avp) + z1 = z + alpha * d + if np.linalg.norm(z1) >= delta: + a_quad = np.vdot(d, d) + b_quad = 2 * np.vdot(d, z) + c_quad = np.vdot(z, z) - delta**2 + alpha = np.roots([a_quad, b_quad, c_quad]) + alpha_opt = np.max(alpha) + z = z + alpha_opt * d + flag_cg = 0 + flag_step = "C" + # logging.debug("qCG: %e its %d x1 %e",rdotr,i,np.linalg.norm(z1)) + return z, flag_cg, flag_step + + r += alpha * _Avd + + new_rdotr = np.vdot(r, r) # torch.dot(r,r) + + beta = new_rdotr / rdotr + + d = -r + beta * d + + rdotr = new_rdotr + z = z1 + + if np.sqrt(rdotr) <= residual_tol or np.sqrt(rdotr) / nrmb <= residual_tol: + flag_step = "N" + flag_cg = 0 + break + # logging.debug("CG: %e its %d res %e",rdotr,i,np.linalg.norm(hessOrMat.hessvec(z)-b)) + + return z, flag_cg, flag_step diff --git a/QNTRPO/utils_trpo.py b/QNTRPO/utils_trpo.py new file mode 100644 index 0000000..47bbc74 --- /dev/null +++ b/QNTRPO/utils_trpo.py @@ -0,0 +1,85 @@ +# Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters + +use_cuda = torch.cuda.is_available() + + +def normal_entropy(std): + + var = std.pow(2) + entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi) + + return entropy.sum(1, keepdim=True) + + +def normal_log_density(x, mean, log_std, std): + var = std.pow(2) + log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std + + return log_density.sum(1, keepdim=True) + + +def get_flat_params_from(model): + + params = [] + for param in model.parameters(): + params.append(param.data.view(-1)) + + flat_params = torch.cat(params) + + return flat_params + + +def set_flat_params_to(model, flat_params): + + prev_ind = 0 + + for param in model.parameters(): + flat_size = int(np.prod(list(param.size()))) + + param.data.copy_(flat_params[prev_ind : prev_ind + flat_size].view(param.size())) + + prev_ind += flat_size + + +def get_flat_grad_from(net, grad_grad=False): + + grads = [] + + for param in net.parameters(): + if grad_grad: + grads.append(param.grad.grad.view(-1)) + + else: + + grads.append(param.grad.view(-1)) + + flat_grad = torch.cat(grads) + + return flat_grad + + +def Variable(tensor, *args, **kwargs): + + if use_cuda: + return torch.autograd.Variable(tensor, *args, **kwargs).cuda() + else: + return torch.autograd.Variable(tensor, *args, **kwargs) + + +def Tensor(nparray): + + 
if use_cuda:
+        return torch.tensor(nparray).cuda()
+    else:
+        return torch.Tensor(nparray)
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..696996d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,90 @@
+
+# QNTRPO: Quasi-Newton Trust Region Policy Optimization
+
+System requirements:
+The code has been tested in the following environment.
+
+1. Ubuntu 16.04 LTS
+2. Python 3.6.7 (it will not work on Python 3.6.0 due to incompatibilities between PyTorch and that release)
+3. Torch 1.1.0 (more recent versions of PyTorch should also work)
+4. Mujoco_py==1.50
+5. Gym
+
+## Features
+
+QNTRPO solves the policy optimization problem that arises in reinforcement learning using a quasi-Newton trust region algorithm.
+
+## Installation
+
+The code depends on several external libraries. Install them following the instructions below; we describe the installation inside a conda virtual environment.
+```
+conda create -n qntrpo python=3.11 anaconda
+
+source activate qntrpo
+
+conda install pytorch
+```
+
+Install Mujoco and mujoco-py following the instructions in https://github.com/openai/mujoco-py (License: `MIT`)
+
+Install Gym following the instructions in https://github.com/openai/gym (License: `MIT`)
+
+## Usage
+
+To change the trust region radius used for optimization, edit the parameter "tr_maxdelta" on line 67 of "trust_region_opt_torch.py". The current value is 1e-1, and we suggest running the code with this value; the performance of the algorithm with other values has not been fully tested yet.
+
+A different batch size can be used by passing the additional argument `--batch-size N`, where N is an integer (say 25000), e.g.,
+
+```
+python main.py --env-name "Walker2d-v2" --seed 1243 --batch-size 25000
+```
+
+## Testing
+
+The QNTRPO algorithm can be tested by running the following in a terminal (for example, for Walker2d with seed 1243).
+```
+python main.py --env-name "Walker2d-v2" --seed 1243
+```
+
+## Citation
+
+If you use the software, please cite the following ([TR2019-120](https://www.merl.com/publications/TR2019-120)):
+
+```bibTeX
+@inproceedings{Jha2019oct,
+author = {Jha, Devesh K. and Raghunathan, Arvind and Romeres, Diego},
+title = {Quasi-Newton Trust Region Policy Optimization},
+booktitle = {Conference on Robot Learning (CoRL)},
+year = 2019,
+editor = {Leslie Pack Kaelbling and Danica Kragic and Komei Sugiura},
+pages = {945--954},
+month = oct,
+publisher = {Proceedings of Machine Learning Research},
+url = {https://www.merl.com/publications/TR2019-120}
+}
+```
+
+## Contact
+
+Please contact one of us: Devesh K. Jha (jha@merl.com), Arvind U. Raghunathan (raghunathan@merl.com), or Diego Romeres (romeres@merl.com).
+
+## Contributing
+
+See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions.
+
+## License
+
+Released under the `AGPL-3.0-or-later` license, as found in the [LICENSE.md](LICENSE.md) file.
+
+All files:
+
+```
+Copyright (C) 2019, 2023 Mitsubishi Electric Research Laboratories (MERL).
+
+SPDX-License-Identifier: AGPL-3.0-or-later
+```
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3d26d6a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+# Copyright (C) 2020, 2023 Mitsubishi Electric Research Laboratories (MERL).
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+torch==1.1.0
+gym==0.15.4
+tensorboard==2.0.2
+numpy==1.18.0