diff --git a/.github/workflows/requirements-dev.txt b/.github/workflows/requirements-dev.txt new file mode 100644 index 0000000..42bce9e --- /dev/null +++ b/.github/workflows/requirements-dev.txt @@ -0,0 +1,5 @@ +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +pre-commit diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml new file mode 100644 index 0000000..1a01346 --- /dev/null +++ b/.github/workflows/static_checks.yaml @@ -0,0 +1,76 @@ +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: Static code checks + +on: # yamllint disable-line rule:truthy + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +env: + LICENSE: AGPL-3.0-or-later + FETCH_DEPTH: 1 + FULL_HISTORY: 0 + SKIP_WORD_PRESENCE_CHECK: 0 + +jobs: + static-code-check: + if: endsWith(github.event.repository.name, 'private') + + name: Run static code checks + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Setup history + if: github.ref == 'refs/heads/oss' + run: | + echo "FETCH_DEPTH=0" >> $GITHUB_ENV + echo "FULL_HISTORY=1" >> $GITHUB_ENV + + - name: Setup version + if: github.ref == 'refs/heads/melco' + run: | + echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV + + - name: Check out code + uses: actions/checkout@v3 + with: + fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history + + - name: Set up environment + run: git config user.email github-bot@merl.com + + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: 3 + cache: 'pip' + cache-dependency-path: '.github/workflows/requirements-dev.txt' + + - name: Install python packages + run: pip install -r .github/workflows/requirements-dev.txt + + - name: Ensure lint and pre-commit steps have been run + uses: pre-commit/action@v3.0.0 + + - name: Check files + uses: merl-oss-private/merl-file-check-action@v1 + with: + license: ${{ env.LICENSE }} + full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above + skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} + + - name: Check license compatibility + if: github.ref != 'refs/heads/melco' + uses: merl-oss-private/merl_license_compatibility_checker@v1 + with: + input-filename: environment.yml + license: ${{ env.LICENSE }} diff --git a/.gitignore b/.gitignore index b6e4761..2ffd879 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..bfabfe7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,64 @@ +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# Pre-commit configuration. 
See https://pre-commit.com + +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=5000'] + + - repo: https://gitlab.com/bmares/check-json5 + rev: v1.0.0 + hooks: + - id: check-json5 + + - repo: https://github.com/homebysix/pre-commit-macadmin + rev: v1.12.3 + hooks: + - id: check-git-config-email + args: ['--domains', 'merl.com'] + + - repo: https://github.com/psf/black + rev: 22.12.0 + hooks: + - id: black + args: + - --line-length=120 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] + + # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python + # - repo: https://github.com/asottile/pyupgrade + # rev: v3.3.1 + # hooks: + # - id: pyupgrade + + # To stop flake8 error from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings, + # so use verbose: true to see them. + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + # Black compatibility, Eradicate options + args: ["--max-line-length=120", "--extend-ignore=E203", + "--eradicate-whitelist-extend", "eradicate:\\s*no", + "--exit-zero"] + verbose: true + additional_dependencies: [ + # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate + "flake8-eradicate" + ] diff --git a/.reuse/dep5 b/.reuse/dep5 new file mode 100644 index 0000000..226fc34 --- /dev/null +++ b/.reuse/dep5 @@ -0,0 +1,13 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ + +Files: input_example/skins/*.png +Copyright: 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +License: AGPL-3.0-or-later + +Files: input_example/hairs/*.png +Copyright: 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +License: AGPL-3.0-or-later + +Files: utils/*.png +Copyright: 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +License: AGPL-3.0-or-later diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..95eb1e8 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,9 @@ + +# Contributing + +Sorry, but we do not currently accept contributions in the form of pull requests +to this repository. However, you are welcome to post issues (bug reports, feature requests, questions, etc). diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..6bb5339 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,661 @@ +GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. +Everyone is permitted to copy and distribute verbatim copies +of this license document, but changing it is not allowed. + + Preamble + +The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + +When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + +A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + +The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + +An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + +The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + +0. Definitions. + +"This License" refers to version 3 of the GNU Affero General Public License. + +"Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. 
+ +An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +1. Source Code. + +The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + +The Corresponding Source for a work in source code form is that +same work. + +2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + +3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + +4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + +a) The work must carry prominent notices stating that you modified +it, and giving a relevant date. + +b) The work must carry prominent notices stating that it is +released under this License and any conditions added under section +7. This requirement modifies the requirement in section 4 to +"keep intact all notices". + +c) You must license the entire work, as a whole, under this +License to anyone who comes into possession of a copy. This +License will therefore apply, along with any applicable section 7 +additional terms, to the whole of the work, and all its parts, +regardless of how they are packaged. This License gives no +permission to license the work in any other way, but it does not +invalidate such permission if you have separately received it. + +d) If the work has interactive user interfaces, each must display +Appropriate Legal Notices; however, if the Program has interactive +interfaces that do not display Appropriate Legal Notices, your +work need not make them do so. 
+ +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + +a) Convey the object code in, or embodied in, a physical product +(including a physical distribution medium), accompanied by the +Corresponding Source fixed on a durable physical medium +customarily used for software interchange. + +b) Convey the object code in, or embodied in, a physical product +(including a physical distribution medium), accompanied by a +written offer, valid for at least three years and valid for as +long as you offer spare parts or customer support for that product +model, to give anyone who possesses the object code either (1) a +copy of the Corresponding Source for all the software in the +product that is covered by this License, on a durable physical +medium customarily used for software interchange, for a price no +more than your reasonable cost of physically performing this +conveying of source, or (2) access to copy the +Corresponding Source from a network server at no charge. + +c) Convey individual copies of the object code with a copy of the +written offer to provide the Corresponding Source. This +alternative is allowed only occasionally and noncommercially, and +only if you received the object code with such an offer, in accord +with subsection 6b. + +d) Convey the object code by offering access from a designated +place (gratis or for a charge), and offer equivalent access to the +Corresponding Source in the same way through the same place at no +further charge. You need not require recipients to copy the +Corresponding Source along with the object code. If the place to +copy the object code is a network server, the Corresponding Source +may be on a different server (operated by you or a third party) +that supports equivalent copying facilities, provided you maintain +clear directions next to the object code saying where to find the +Corresponding Source. Regardless of what server hosts the +Corresponding Source, you remain obligated to ensure that it is +available for as long as needed to satisfy these requirements. + +e) Convey the object code using peer-to-peer transmission, provided +you inform other peers where the object code and Corresponding +Source of the work are being offered to the general public at no +charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. 
In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + +a) Disclaiming warranty or limiting liability differently from the +terms of sections 15 and 16 of this License; or + +b) Requiring preservation of specified reasonable legal notices or +author attributions in that material or in the Appropriate Legal +Notices displayed by works containing it; or + +c) Prohibiting misrepresentation of the origin of that material, or +requiring that modified versions of such material be marked in +reasonable ways as different from the original version; or + +d) Limiting the use for publicity purposes of names of licensors or +authors of the material; or + +e) Declining to grant rights under trademark law for use of some +trade names, trademarks, or service marks; or + +f) Requiring indemnification of licensors and authors of that +material by anyone who conveys the material (or modified versions of +it) with contractual assumptions of liability to the recipient, for +any liability that these contractual assumptions directly impose on +those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + +8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + +13. Remote Network Interaction; Use with the GNU General Public License. 
+ +Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + +14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + +Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + +If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + +16. Limitation of Liability. 
+ +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + +17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + +How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + +To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + +Copyright (C) + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published +by the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + +If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + +You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 0000000..e8f691d --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,204 @@ + + + +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/LICENSES/MIT.txt b/LICENSES/MIT.txt new file mode 100644 index 0000000..a450b2b --- /dev/null +++ b/LICENSES/MIT.txt @@ -0,0 +1,134 @@ +MIT License + +Copyright (c) 2022 Omri Avrahami + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +MIT License + +Copyright (c) 2020 Assaf Shocher + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +MIT License + +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE + +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +MIT License + +Copyright (c) 2021 Alec Radford + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +MIT License + +Copyright (c) 2021 Po-Hsun-Su + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 3819a44..977f9ca 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,204 @@ -# DiffusionFace -Convert a pretrained unconditional diffusion model to a conditional one using steering networks + + +# Steered Diffusion (ICCV 2023) + + +This repository contains the implementation of the paper: +> **Steered Diffusion: A Generalized Framework for Plug-and-Play Conditional Image Synthesis**
+> [Nithin Gopalakrishnan Nair](https://nithin-gk.github.io/), [Anoop Cherian](https://www.merl.com/people/cherian), [Suhas Lohit](https://suhaslohit.github.io), [Ye Wang](https://www.merl.com/people/yewang), [Toshiaki Koike-Akino](https://www.merl.com/people/koike), [Vishal M Patel](https://engineering.jhu.edu/vpatel36/vishal-patel), [Tim K Marks](https://www.merl.com/people/tmarks) + +IEEE/CVF International Conference on Computer Vision (**ICCV**), 2023 + +From [Mitsubishi Electric Research Labs](https://www.merl.com/) and [VIU Lab](https://engineering.jhu.edu/vpatel36/), Johns Hopkins University + + +[[Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Nair_Steered_Diffusion_A_Generalized_Framework_for_Plug-and-Play_Conditional_Image_Synthesis_ICCV_2023_paper.html)] | +[[Project Page](https://www.merl.com/demos/steered-diffusion)] + +Keywords: Zero-Shot Generation, Conditional Face Generation, Multimodal Face Generation, Text-to-Image Generation, Diffusion-Based Face Generation + +## Summary + + + +We propose **Steered Diffusion**, a framework that lets users perform zero-shot conditional generation with a pretrained unconditional diffusion model (a minimal sketch of the steering update is shown below). + *(a) Linear inverse problems*: generates clean images from linear degradations such as super-resolution, colorization, and inpainting. + *(b) Complex network-based conditioning*: users can provide generic conditions such as semantic segmentation maps, an identity image, or a text prompt. + +### Contributions: + +- We propose Steered Diffusion, a general plug-and-play framework that can utilize various pre-existing models to steer an unconditional diffusion model. +- We present the first work applicable to both label-level synthesis and image-to-image translation tasks, and we demonstrate its effectiveness for various applications. +- We propose an implicit conditioning-based sampling strategy that significantly boosts the performance of conditional sampling from unconditional diffusion models compared to previous methods. +- We introduce a new strategy that utilizes multiple steps of projected gradient descent to improve sample quality. +

+ <!-- Figure placeholder: Centered Image --> +

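At a high level, each reverse diffusion step steers the model's clean-image estimate using the gradient of a loss computed by a frozen, pretrained auxiliary network. The snippet below is a minimal sketch of that update, condensed from `conditional_sample` in `guided_diffusion/guided_diffusion/gaussian_diffusion.py`; the names `diffusion`, `model`, `cond_fn`, and `init_image` are placeholders, and details such as the multiple projected-gradient iterations (`num_iters`) and per-condition scheduling are omitted.

```python
import numpy as np
import torch as th


def steering_step(diffusion, model, x_t, t, cond_fn, init_image):
    """One steered reverse step (illustrative sketch, not the full implementation)."""
    with th.enable_grad():
        x_t = x_t.detach().requires_grad_()
        # Predict x_0 from the current noisy sample x_t with the unconditional model.
        out = diffusion.p_mean_variance(model, x_t, t, clip_denoised=True, model_kwargs={})
        pred_x0 = out["pred_xstart"]
        # A frozen auxiliary network scores how well pred_x0 matches the condition.
        loss = cond_fn(pred_x0, init_image, t)
        grad = -th.autograd.grad(loss, x_t)[0]
    # Steer the x_0 estimate along the gradient; the step size tracks the noise level.
    scale = float(np.sqrt(1.0 - diffusion.alphas_cumprod_next[int(t[0])]))
    steered_x0 = pred_x0 + scale * grad
    # Take a stochastic DDIM-style step toward the steered estimate.
    return diffusion.pred_sample(x_t, t, steered_x0)
```

For the linear conditions (`grayscale`, `SR`, `inpaint`), the gradient step is replaced by a direct projection that copies the observed content (the luminance channel, the low-frequency band, or the unmasked region) into the estimate before the same DDIM-style step.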
+ +## Environment setup + + +``` +conda env create -f environment.yml +``` + + +## Pretrained models + +Please download the pretrained models using +``` +python utils/download_models.py +``` + + +# Testing on custom datasets + +## Data Preparation +You can test on any custom dataset by arranging the data in the following format. Note that you can provide either one or several modalities. +``` + ├── data + | ├── images + | └── masks +``` + +## Testing code on images +For simplicity, the testing code expects a clean RGB image as the input for the `--img_path` argument below. The degraded image is derived from this input by the code, based on the condition being tested. Please modify the code as needed to input a degraded image directly. Also note that, for inpainting, a separate mask file needs to be provided. +For testing the conditions: +``` +grayscale: converts a grayscale image to an RGB image (colorization) +SR: super-resolution +inpaint: fills in the masked region of the image +Identity: generates new images with the same identity as the given image +Semantics: generates new images with the same semantic face parsing map as the given image +``` +please use the command +``` +python steered_generate.py --config configs/diffusion_config.yml --img_path /path/to/image --mask_path /path/to/mask --condition "your condition" +``` +For testing image editing, please use the command +``` +python steered_generate.py --config configs/diffusion_config.yml --img_path /path/to/image --mask_path /path/to/mask --condition "editing" --editing_text "the text prompt to add to the image" + +``` +The final output image will be saved in the results directory in the format "condition image | generated sample". + +## Testing code on datasets + +Test on a custom dataset using: +``` +python steered_generate_dataset.py --data_fold /path/to/data --condition "your condition" --config configs/diffusion_config.yml +``` + +Please set the flags you need for the generation. + + +## Testing dataset + +We performed experiments on the first 300 images of the CelebA-Multimodal (MM-CelebA-HQ) dataset, which can be downloaded from +``` +https://github.com/IIGROUP/MM-CelebA-HQ-Dataset +``` + +## Citation +If you use our work, please use the following citation: + +```bibtex +@inproceedings{nair2023steered, + title={Steered Diffusion: A Generalized Framework for Plug-and-Play Conditional Image Synthesis}, + author={Nair, Nithin Gopalakrishnan and Cherian, Anoop and Lohit, Suhas and Wang, Ye and Koike-Akino, Toshiaki and Patel, Vishal M and Marks, Tim K}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={20850--20860}, + year={2023} +} +``` + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions. + +## License + +Released under the `AGPL-3.0-or-later` license, as found in the [LICENSE.md](LICENSE.md) file.
+ +All files, except as noted below: +``` +Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) + +SPDX-License-Identifier: AGPL-3.0-or-later +``` + +The following files: + +* `guided_diffusion/guided_diffusion/__init__.py`,`guided_diffusion/guided_diffusion/fp16_util.py` +* `guided_diffusion/guided_diffusion/logger.py`,`guided_diffusion/guided_diffusion/nn.py` +* `guided_diffusion/guided_diffusion/respace.py`,`guided_diffusion/guided_diffusion/script_util.py` +* `guided_diffusion/guided_diffusion/unet.py` + +were taken without modification from https://github.com/openai/guided-diffusion (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2021 OpenAI +``` + +The following files: + +* `losses/ssim.py` + +were taken without modification from https://github.com/Po-Hsun-Su/pytorch-ssim/tree/master (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2021 Po-Hsun-Su +``` + +The following files: + +* `guided_diffusion/guided_diffusion/interp_methods.py` + +were taken without modification from https://github.com/assafshocher/ResizeRight/blob/master/interp_methods.py (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2020 Assaf Shocher +``` + +The following files + +* `guided_diffusion/guided_diffusion/resize_right.py` + +were adapted from https://github.com/assafshocher/ResizeRight/blob/master/resize_right.py (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +Copyright (c) 2020 Assaf Shocher +``` + +The following files +* `guided_diffusion/guided_diffusion/gaussian_diffusion.py` + +were adapted from https://github.com/openai/guided-diffusion (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +Copyright (c) 2021 OpenAI +``` + +The following files +* `steered_diffusion.py`,`steered_diffusion_dataset.py`,`parser.py` + +were adapted from https://github.com/omriav/blended-diffusion (license included in [LICENSES/MIT.txt](LICENSES/MIT.txt)): + +``` +Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +Copyright (C) 2022-2023 Omri Avrahami +``` + +The following files +* `utils/download_models.py`,`utils/download_models_func.py` + +were adapted from https://github.com/Nithin-GK/UniteandConquer/blob/main/download_models.py (license included in [LICENSES/Apache-2.0.txt](LICENSES/Apache-2.0.txt)): + +``` +# Copyright (C) 2022-2023 Nithin Gopalakrishnan Nair +``` diff --git a/configs/diffusion_config.yml b/configs/diffusion_config.yml new file mode 100644 index 0000000..90628b4 --- /dev/null +++ b/configs/diffusion_config.yml @@ -0,0 +1,107 @@ +# Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +checkpoints: + arcface: checkpoints/arc face18.pth + faceparse: checkpoints/face_parse.pth + ffhq: checkpoints/ffhq_10m.pt + vggface: checkpoints/VGG_FACE.pth + farlclip: checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep64.pth #./checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep16.pth +data: + init_image: ./input_example/faces/4.jpg + mask_image: ./input_example/masks/4.png + +gpu_id: 0 +manual_seed: 0 +name: Diff_edit +diffusion_network: + attention_resolutions: '16' + class_cond: false + diffusion_steps: 1000 + image_size: 256 + learn_sigma: true + noise_schedule: linear + num_channels: 128 + num_head_channels: 64 + num_res_blocks: 1 + 
resblock_updown: true + rescale_timesteps: true + timestep_respacing: '100' + use_fp16: false + use_scale_shift_norm: true + + +# diffusion_network: non-face +# attention_resolutions: '32,16,8' +# class_cond: false +# diffusion_steps: 1000 +# image_size: 256 +# learn_sigma: true +# noise_schedule: linear +# num_channels: 256 +# num_head_channels: 64 +# num_res_blocks: 2 +# resblock_updown: true +# rescale_timesteps: true +# timestep_respacing: '100' +# use_fp16: True +# use_scale_shift_norm: true +num_gpu: 1 + + +networks: + VGGface: + checkpoint: ./checkpoints/VGG_FACE.pth + multiscale: + use: false + min_t: 0 + max_t: 100 + lambda: 10 + singlescale: + use: false + min_t: 0 + max_t: 100 + lambda: 10 + + + Semantics: + criterion: nn.BCEWithLogitsLoss + face_segment_parse: + use: false + min_t: 0 + max_t: 100 + lambda: 30000 + + FARL: + checkpoint: ./checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep64.pth + farlclip: + use: false + min_t: 0 + max_t: 90 + lambda: 1500 + prompt: A woman with red hair + farledit: + use: false + min_t: 0 + max_t: 85 + lambda: 1500 + prompt: Red hair + farlidentity: + use: false + min_t: 0 + max_t: 100 + lambda: 3000 + + + +params: + batch_size: 1 + image_size: 256 + cond: Semantics + scale_factor: 4 + use_ddim: false + results_dir: ./results + data_path_fold: ./data + +seed: 404 diff --git a/configs/diffusion_config_imagenet.yml b/configs/diffusion_config_imagenet.yml new file mode 100644 index 0000000..cd4fe8b --- /dev/null +++ b/configs/diffusion_config_imagenet.yml @@ -0,0 +1,88 @@ +# Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +checkpoints: + arcface: checkpoints/arc face18.pth + faceparse: checkpoints/face_parse.pth + ffhq: checkpoints/diffusion256x256.pt + vggface: checkpoints/VGG_FACE.pth + farlclip: checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep64.pth #./checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep16.pth +data: + init_image: ./input_example/faces/4.jpg + mask_image: ./input_example/masks/4.png + +gpu_id: 0 +manual_seed: 0 +name: Diff_edit + + +diffusion_network: + attention_resolutions: '32,16,8' + class_cond: false + diffusion_steps: 1000 + image_size: 256 + learn_sigma: true + noise_schedule: linear + num_channels: 256 + num_head_channels: 64 + num_res_blocks: 2 + resblock_updown: true + rescale_timesteps: true + timestep_respacing: '100' + use_fp16: True + use_scale_shift_norm: true + +num_gpu: 1 + + +networks: + VGGface: + checkpoint: ./checkpoints/VGG_FACE.pth + multiscale: + use: false + min_t: 0 + max_t: 100 + lambda: 10 + + + Semantics: + criterion: nn.BCEWithLogitsLoss + face_segment_parse: + use: false + min_t: 0 + max_t: 100 + lambda: 20000 + + FARL: + checkpoint: ./checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep64.pth + farlclip: + use: false + min_t: 0 + max_t: 90 + lambda: 1500 + prompt: A woman with blonde hair + farledit: + use: false + min_t: 0 + max_t: 90 + lambda: 100 + prompt: Red hair + farlidentity: + use: false + min_t: 0 + max_t: 100 + lambda: 3000 + + + +params: + batch_size: 1 + image_size: 256 + cond: Semantics + scale_factor: 4 + use_ddim: false + results_dir: ./results + data_path_fold: ./data + +seed: 404 diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..8875e4f --- /dev/null +++ b/environment.yml @@ -0,0 +1,45 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +name: SteeredDiffusion +channels: + - pytorch + - conda-forge + - defaults 
+dependencies: + - cffi=1.15.0=py37h7f8727e_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - git=2.19.1=pl526h7fee0ce_0 + - ipython=7.33.0=py37h89c1867_0 + - pip=22.3.1=pyhd8ed1ab_0 + - python=3.7.0=h6e4f718_3 + - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 + - urllib3=1.26.9=py37h06a4308_0 + - pip: + - git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 + - easydict==1.10 + - einops==0.6.0 + - git+https://github.com/FacePerceiver/facer.git@b2e4fcb94fa8db7cff1a053b6c6441131274c074 + - huggingface-hub==0.12.1 + - kornia==0.5.0 + - lpips==0.1.4 + - matplotlib==3.5.3 + - numpy==1.21.6 + - omegaconf==2.1.2 + - open-clip-torch==2.7.0 + - opencv-python==4.6.0.66 + - pandas==1.1.5 + - pillow==9.1.1 + - pyyaml==6.0 + - scikit-image==0.19.3 + - scikit-learn==1.0.2 + - scipy==1.7.3 + - timm==0.9.12 + - tokenizers==0.12.1 + - torchmetrics==0.11.0 + - torchvision==0.12.0 + - tqdm==4.64.1 + - transformers==4.19.2 + - webdataset==0.2.31 + - yarl==1.8.2 diff --git a/guided_diffusion/.gitignore b/guided_diffusion/.gitignore new file mode 100644 index 0000000..c0c18e5 --- /dev/null +++ b/guided_diffusion/.gitignore @@ -0,0 +1,6 @@ +# Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +.DS_Store +__pycache__/ diff --git a/guided_diffusion/guided_diffusion/__init__.py b/guided_diffusion/guided_diffusion/__init__.py new file mode 100644 index 0000000..f5e2071 --- /dev/null +++ b/guided_diffusion/guided_diffusion/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Codebase for "Improved Denoising Diffusion Probabilistic Models". +""" diff --git a/guided_diffusion/guided_diffusion/fp16_util.py b/guided_diffusion/guided_diffusion/fp16_util.py new file mode 100644 index 0000000..38f3496 --- /dev/null +++ b/guided_diffusion/guided_diffusion/fp16_util.py @@ -0,0 +1,226 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (c) 2021 OpenAI + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code taken from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License + +import numpy as np +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from . import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. 
+ """ + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors([param.detach().float() for (_, param) in param_group]).view(shape) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + for master_param, (param_group, shape) in zip(master_params, param_groups_and_shapes): + master_param.grad = _flatten_dense_tensors([param_grad_or_zeros(param) for (_, param) in param_group]).view( + shape + ) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (_, param), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + param.detach().copy_(unflat_master_param) + + +def unflatten_master_params(param_group, master_param): + return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) + + +def get_param_groups_and_shapes(named_model_params): + named_model_params = list(named_model_params) + scalar_vector_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim <= 1], + (-1), + ) + matrix_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim > 1], + (1, -1), + ) + return [scalar_vector_named_params, matrix_named_params] + + +def master_params_to_state_dict(model, param_groups_and_shapes, master_params, use_fp16): + if use_fp16: + state_dict = model.state_dict() + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (name, _), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + assert name in state_dict + state_dict[name] = unflat_master_param + else: + state_dict = model.state_dict() + for i, (name, _value) in enumerate(model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + +def state_dict_to_master_params(model, state_dict, use_fp16): + if use_fp16: + named_model_params = [(name, state_dict[name]) for name, _ in model.named_parameters()] + param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + master_params = make_master_params(param_groups_and_shapes) + else: + master_params = [state_dict[name] for name, _ in model.named_parameters()] + return master_params + + +def zero_master_grads(master_params): + for param in master_params: + param.grad = None + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +def param_grad_or_zeros(param): + if param.grad is not None: + return param.grad.data.detach() + else: + return th.zeros_like(param) + + +class MixedPrecisionTrainer: + def __init__( + self, + *, + model, + use_fp16=False, + fp16_scale_growth=1e-3, + initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, + ): + self.model = model + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + + self.model_params = 
list(self.model.parameters()) + self.master_params = self.model_params + self.param_groups_and_shapes = None + self.lg_loss_scale = initial_lg_loss_scale + + if self.use_fp16: + self.param_groups_and_shapes = get_param_groups_and_shapes(self.model.named_parameters()) + self.master_params = make_master_params(self.param_groups_and_shapes) + self.model.convert_to_fp16() + + def zero_grad(self): + zero_grad(self.model_params) + + def backward(self, loss: th.Tensor): + if self.use_fp16: + loss_scale = 2**self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize(self, opt: th.optim.Optimizer): + if self.use_fp16: + return self._optimize_fp16(opt) + else: + return self._optimize_normal(opt) + + def _optimize_fp16(self, opt: th.optim.Optimizer): + logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) + model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) + grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) + if check_overflow(grad_norm): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + zero_master_grads(self.master_params) + return False + + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + + self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) + opt.step() + zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: th.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with th.no_grad(): + param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 + if p.grad is not None: + grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict(self.model, self.param_groups_and_shapes, master_params, self.use_fp16) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/guided_diffusion/guided_diffusion/gaussian_diffusion.py b/guided_diffusion/guided_diffusion/gaussian_diffusion.py new file mode 100644 index 0000000..684a36c --- /dev/null +++ b/guided_diffusion/guided_diffusion/gaussian_diffusion.py @@ -0,0 +1,699 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (c) 2021 OpenAI + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License + + +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 
+""" + +import enum +import math +from random import sample + +import kornia +import numpy as np +import torch as th + +from .nn import mean_flat +from .resize_right import Resizer + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. 
+ :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + sample = model_mean + nonzero_mask * th.exp(0.5 * model_log_variance) * noise + # "sample": + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "sample": sample, + "model_output": model_output, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t + ) + + def _predict_xprev_from_xstart(self, x_t, t, x_start): + assert x_t.shape == x_start.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def conditional_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + postprocess_fn=None, + randomize_class=False, + ): + + final = None + # print + for sample in self.conditional_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + postprocess_fn=postprocess_fn, + randomize_class=randomize_class, + ): + final = sample + return final["sample"] + + def conditional_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + postprocess_fn=None, + randomize_class=False, + dest_fold=None, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + init_image = model_kwargs["init_image"] + noise = None + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + indices = list(range(self.num_timesteps))[::-1] + + init_image_batch = img + + img = self.q_sample( + x_start=init_image_batch, + t=th.tensor(indices[0], dtype=th.long, device=device), + noise=img, + ) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + + with th.no_grad(): + out = self.conditional_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + + yield out + img = out["sample"] + + def pred_sample(self, x, t, start): + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + eps = self._predict_eps_from_xstart(x, t, start) + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + eta = 1 + + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. 
+ noise = th.randn_like(x) + mean_pred = start * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + + sample = mean_pred + nonzero_mask * sigma * noise + + return sample + + @th.no_grad() + def conditional_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=1, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + new_kwargs = {} + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + for _ in range(model_kwargs["num_iters"]): + + shape = x.shape + scale = np.sqrt(1 - self.alphas_cumprod_next[t]) # + + cond = model_kwargs["cond"] + # print + if "grayscale" in cond: + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs={}, + ) + + init_image = th.clone(model_kwargs["init_image"]) # + in_samp_ycbcr = kornia.color.rgb_to_ycbcr(init_image) + out_samp_ycbcr = kornia.color.rgb_to_ycbcr(out["pred_xstart"]) + scale = 1 + out_samp_ycbcr[:, 0, :, :] = in_samp_ycbcr[:, 0, :, :] + start = kornia.color.ycbcr_to_rgb(out_samp_ycbcr) + degraded = in_samp_ycbcr[:, 0, :, :].repeat(1, 3, 1, 1) + + elif "SR" in cond: + factor = model_kwargs["factor"] + shape_d = (shape[0], shape[1], shape[2] // factor, shape[3] // factor) + up = Resizer(in_shape=shape_d, scale_factors=factor, interp_method="cubic").to( + next(model.parameters()).device + ) + down = Resizer(in_shape=shape, scale_factors=1 / factor, interp_method="cubic").to( + next(model.parameters()).device + ) + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs={}, + ) + init_image = th.clone(model_kwargs["init_image"]) # + model_kwargs["num_iters"] = 1 + start = out["pred_xstart"] - (up(down(out["pred_xstart"])) - up(down(init_image))) + degraded = up(down(init_image)) + + elif "inpaint" in cond: + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs={}, + ) + init_image_mask = 1 - th.clone(model_kwargs["mask_image"]) + init_image = th.clone(model_kwargs["init_image"]) + start = out["pred_xstart"] + (init_image * (init_image_mask) - out["pred_xstart"] * init_image_mask) + degraded = init_image * (init_image_mask) + else: + with th.enable_grad(): + x = x.detach().requires_grad_() + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs={}, + ) + cond_input = out["pred_xstart"] + init_image = th.clone(model_kwargs["init_image"]) # + loss_new = cond_fn(cond_input, init_image, t) + grad_val_new = -th.autograd.grad(loss_new, x, allow_unused=True)[0] + start = out["pred_xstart"] + scale * grad_val_new + degraded = init_image + + noise1 = th.randn_like(x) + sample = self.pred_sample(x, t, start) + x = sample * (np.sqrt(1 - self.betas[t])) + self.betas[t] * noise1 + + return {"sample": sample, "pred_xstart": out["pred_xstart"], "degraded": degraded} + + def ddim_reverse_sample_loop( + self, + model, + x, + cond_fn=None, + clip_denoised=False, + denoised_fn=None, + model_kwargs=None, + eta=0, + device=None, + ): + if device is None: + device = next(model.parameters()).device + sample_t = [] + xstart_t = [] + T = [] + indices = list(range(self.num_timesteps)) + sample = x + for i in indices: + t = th.tensor([i] * len(sample), device=device) + with th.no_grad(): + out = self.ddim_reverse_sample( + 
model, + sample, + t=t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + sample = out["sample"] + # [1, ..., T] + sample_t.append(sample) + # [0, ...., T-1] + xstart_t.append(out["pred_xstart"]) + # [0, ..., T-1] ready to use + T.append(t) + + return { + # xT " + "sample": sample, + # (1, ..., T) + "sample_t": sample_t, + # xstart here is a bit different from sampling from T = T-1 to T = 0 + # may not be exact + "xstart_t": xstart_t, + "T": T, + } + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/guided_diffusion/guided_diffusion/interp_methods.py b/guided_diffusion/guided_diffusion/interp_methods.py new file mode 100644 index 0000000..a8db028 --- /dev/null +++ b/guided_diffusion/guided_diffusion/interp_methods.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020 Assaf Shocher +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/assafshocher/ResizeRight/blob/master/interp_methods.py -- MIT License + + +from math import pi + +try: + import torch +except ImportError: + torch = None + +try: + import numpy +except ImportError: + numpy = None + +if numpy is None and torch is None: + raise ImportError("Must have either Numpy or PyTorch but both not found") + + +def set_framework_dependencies(x): + if type(x) is numpy.ndarray: + to_dtype = lambda a: a + fw = numpy + else: + to_dtype = lambda a: a.to(x.dtype) + fw = torch + eps = fw.finfo(fw.float32).eps + return fw, to_dtype, eps + + +def support_sz(sz): + def wrapper(f): + f.support_sz = sz + return f + + return wrapper + + +@support_sz(4) +def cubic(x): + fw, to_dtype, eps = set_framework_dependencies(x) + absx = fw.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1.0) * to_dtype(absx <= 1.0) + ( + -0.5 * absx3 + 2.5 * absx2 - 4.0 * absx + 2.0 + ) * to_dtype((1.0 < absx) & (absx <= 2.0)) + + +@support_sz(4) +def lanczos2(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return ((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2) + + +@support_sz(6) +def lanczos3(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return ((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3) + + +@support_sz(2) +def linear(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return (x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * to_dtype((0 <= x) & (x <= 1)) + + +@support_sz(1) +def box(x): + fw, to_dtype, eps = set_framework_dependencies(x) + return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) diff --git a/guided_diffusion/guided_diffusion/logger.py b/guided_diffusion/guided_diffusion/logger.py new file mode 100644 index 0000000..4850dc6 --- /dev/null +++ b/guided_diffusion/guided_diffusion/logger.py @@ -0,0 +1,485 @@ +# Copyright (c) 2021 OpenAI +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License + +""" +Logger copied from OpenAI baselines to avoid extra RL-based dependencies: +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import datetime +import json +import os +import os.path as osp +import shutil +import sys +import tempfile +import time +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), "expected file or str, got %s" % 
filename_or_file + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for (key, val) in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append("| %s%s | %s%s |" % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))) + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for (i, elem) in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "wt") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.core.util import event_pb2 + from tensorflow.python import pywrap_tensorflow + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = self.step # is there any reason why you'd want to specify the step? 
+ self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for (k, v) in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. 
(See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()}, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for (name, (val, count)) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn("WARNING: tried to compute mean on non-float {}={}".format(name, val)) + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir=None, format_strs=None, comm=None, log_suffix=""): + """ + If comm is provided, average all numerical stats across that comm + """ + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % rank + + if 
format_strs is None: + if rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger diff --git a/guided_diffusion/guided_diffusion/nn.py b/guided_diffusion/guided_diffusion/nn.py new file mode 100644 index 0000000..6bf2814 --- /dev/null +++ b/guided_diffusion/guided_diffusion/nn.py @@ -0,0 +1,173 @@ +# Copyright (c) 2021 OpenAI +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License +""" +Various utilities for neural networks. +""" +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. 
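+
+    Illustrative note: normalization(128) returns GroupNorm32(32, 128), i.e. a
+    GroupNorm with 32 groups whose forward pass runs in float32 and casts the
+    result back to the input dtype.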
+ """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with th.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with th.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = th.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/guided_diffusion/guided_diffusion/resize_right.py b/guided_diffusion/guided_diffusion/resize_right.py new file mode 100644 index 0000000..de3001d --- /dev/null +++ b/guided_diffusion/guided_diffusion/resize_right.py @@ -0,0 +1,467 @@ +# Copyright (c) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (c) 2020 Assaf Shocher + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/assafshocher/ResizeRight/blob/master/resize_right.py -- MIT License + + +import warnings +from fractions import Fraction +from math import ceil +from typing import Tuple + +import numpy as np + +from . 
import interp_methods +from .interp_methods import box, cubic, lanczos2, lanczos3, linear, support_sz + + +class NoneClass: + pass + + +try: + import torch + from torch import nn + + nnModuleWrapped = nn.Module +except ImportError: + warnings.warn("No PyTorch found, will work only with Numpy") + torch = None + nnModuleWrapped = NoneClass + +try: + import numpy +except ImportError: + warnings.warn("No Numpy found, will work only with PyTorch") + numpy = None + + +if numpy is None and torch is None: + raise ImportError("Must have either Numpy or PyTorch but both not found") + + +class Resizer(nn.Module): + def __init__( + self, + in_shape, + scale_factors=None, + out_shape=None, + interp_method="cubic", + support_sz=None, + antialiasing=True, + by_convs=False, + scale_tolerance=None, + max_numerator=10, + pad_mode="constant", + ): + super(Resizer, self).__init__() + + # get properties of the input tensor + # in_shape, n_dims = input.shape, input.ndim + interpmethods = { + "cubic": cubic, + "lanczos2": lanczos2, + "lanczos3": lanczos3, + "linear": linear, + "box": box, + None: cubic, + } + + interp_method = interpmethods[interp_method] + # fw stands for framework that can be either numpy or torch, + # determined by the input type + n_dims = 4 + fw = numpy if type(input) is numpy.ndarray else torch + eps = fw.finfo(fw.float32).eps + device = "cuda" + self.fw = fw + self.pad_mode = pad_mode + # set missing scale factors or output shapem one according to another, + # scream if both missing. this is also where all the defults policies + # take place. also handling the by_convs attribute carefully. + scale_factors, out_shape, by_convs = self.set_scale_and_out_sz( + in_shape, out_shape, scale_factors, by_convs, scale_tolerance, max_numerator, eps, fw + ) + # sort indices of dimensions according to scale of each dimension. 
+ # since we are going dim by dim this is efficient + self.sorted_filtered_dims_and_scales = [ + (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim]) + for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) + if scale_factors[dim] != 1.0 + ] + + # unless support size is specified by the user, it is an attribute + # of the interpolation method + if support_sz is None: + support_sz = interp_method.support_sz + + # output begins identical to input and changes with each iteration + output = input + self.projected_grids = [] + self.cur_interp_methods = [] + self.cur_support_szs = [] + self.fov = [] + self.weights = [] + self.pad_sz = [] + self.projected_grid = [] + self.dim = [] + self.scale_factor = [] + self.dim_by_convs = [] + self.in_sz = [] + self.out_sz = [] + # iterate over dims + # print(self.sorted_filtered_dims_and_scales) + for (dim, scale_factor, dim_by_convs, in_sz, out_sz) in self.sorted_filtered_dims_and_scales: + self.dim.append(dim) + self.scale_factor.append(scale_factor) + self.dim_by_convs.append(dim_by_convs) + self.in_sz.append(in_sz) + self.out_sz.append(out_sz) + # STEP 1- PROJECTED GRID: The non-integer locations of the projection + # of output pixel locations to the input tensor + projected_grid = self.get_projected_grid(in_sz, out_sz, scale_factor, fw, dim_by_convs, device) + + # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify + # the window size and the interpolation method (see inside function) + cur_interp_method, cur_support_sz = self.apply_antialiasing_if_needed( + interp_method, support_sz, scale_factor, antialiasing + ) + self.cur_interp_methods.append(cur_interp_method) + self.cur_support_szs.append(cur_support_sz) + # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels + # that influence it. Also calculate needed padding and update grid + # accoedingly + field_of_view = self.get_field_of_view(projected_grid, cur_support_sz, fw, eps, device) + + # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view, + # the input should be padded to handle the boundaries, coordinates + # should be updated. actual padding only occurs when weights are + # aplied (step 4). if using by_convs for this dim, then we need to + # calc right and left boundaries for each filter instead. + pad_sz, projected_grid, field_of_view = self.calc_pad_sz( + in_sz, out_sz, field_of_view, projected_grid, scale_factor, dim_by_convs, fw, device + ) + self.pad_sz.append(pad_sz) + self.projected_grids.append(projected_grid) + self.fov.append(field_of_view) + + # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in + # the field of view for each output pixel + weights = self.get_weights(cur_interp_method, projected_grid, field_of_view) + self.weights.append(weights) + # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying + # its set of weights with the pixel values in its field of view. + # We now multiply the fields of view with their matching weights. + # We do this by tensor multiplication and broadcasting. + # if by_convs is true for this dim, then we do this action by + # convolutions. this is equivalent but faster. 
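+        #
+        # illustrative usage (a sketch; shapes and values are only examples):
+        #     resizer = Resizer(in_shape=(1, 3, 256, 256), scale_factors=0.5)
+        #     small = resizer(big)  # big: [1, 3, 256, 256] -> small: [1, 3, 128, 128]
+        # the projected grids, fields of view and weights are all precomputed
+        # here in __init__, so forward() only has to apply them to its input.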
+ + def forward(self, input): + _, n_dims = input.shape, input.ndim + for i in range(len(self.dim)): + field_of_view = self.fov[i] + weights = self.weights[i] + dim = self.dim[i] + scale_factor = self.scale_factor[i] + dim_by_convs = self.dim_by_convs[i] + in_sz = self.in_sz[i] + out_sz = self.out_sz[i] + pad_sz = self.pad_sz[i] + + if not dim_by_convs: + output = self.apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, self.pad_mode, self.fw) + else: + output = self.apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, self.pad_mode, self.fw) + return output + + def get_projected_grid(self, in_sz, out_sz, scale_factor, fw, by_convs, device=None): + # we start by having the ouput coordinates which are just integer locations + # in the special case when usin by_convs, we only need two cycles of grid + # points. the first and last. + grid_sz = out_sz if not by_convs else scale_factor.numerator + out_coordinates = fw_arange(grid_sz, fw, device) + + # This is projecting the ouput pixel locations in 1d to the input tensor, + # as non-integer locations. + # the following fomrula is derived in the paper + # "From Discrete to Continuous Convolutions" by Shocher et al. + return out_coordinates / float(scale_factor) + (in_sz - 1) / 2 - (out_sz - 1) / (2 * float(scale_factor)) + + def get_field_of_view(self, projected_grid, cur_support_sz, fw, eps, device): + # for each output pixel, map which input pixels influence it, in 1d. + # we start by calculating the leftmost neighbor, using half of the window + # size (eps is for when boundary is exact int) + # print() + left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw) + + # then we simply take all the pixel centers in the field by counting + # window size pixels from the left boundary + ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device) + return left_boundaries[:, None] + ordinal_numbers + + def calc_pad_sz(self, in_sz, out_sz, field_of_view, projected_grid, scale_factor, dim_by_convs, fw, device): + if not dim_by_convs: + # determine padding according to neighbor coords out of bound. + # this is a generalized notion of padding, when pad<0 it means crop + pad_sz = [-field_of_view[0, 0].item(), field_of_view[-1, -1].item() - in_sz + 1] + + # since input image will be changed by padding, coordinates of both + # field_of_view and projected_grid need to be updated + field_of_view += pad_sz[0] + projected_grid += pad_sz[0] + + else: + # only used for by_convs, to calc the boundaries of each filter the + # number of distinct convolutions is the numerator of the scale factor + num_convs, stride = scale_factor.numerator, scale_factor.denominator + + # calculate left and right boundaries for each conv. left can also be + # negative right can be bigger than in_sz. such cases imply padding if + # needed. however if# both are in-bounds, it means we need to crop, + # practically apply the conv only on part of the image. + left_pads = -field_of_view[:, 0] + + # next calc is tricky, explanation by rows: + # 1) counting output pixels between the first position of each filter + # to the right boundary of the input + # 2) dividing it by number of filters to count how many 'jumps' + # each filter does + # 3) multiplying by the stride gives us the distance over the input + # coords done by all these jumps for each filter + # 4) to this distance we add the right boundary of the filter when + # placed in its leftmost position. so now we get the right boundary + # of that filter in input coord. 
+ # 5) the padding size needed is obtained by subtracting the rightmost + # input coordinate. if the result is positive padding is needed. if + # negative then negative padding means shaving off pixel columns. + right_pads = ( + ((out_sz - fw_arange(num_convs, fw, device) - 1) // num_convs) * stride # (1) # (2) # (3) + + field_of_view[:, -1] # (4) + - in_sz + + 1 + ) # (5) + + # in the by_convs case pad_sz is a list of left-right pairs. one per + # each filter + + pad_sz = list(zip(left_pads, right_pads)) + + return pad_sz, projected_grid, field_of_view + + def get_weights(self, interp_method, projected_grid, field_of_view): + # the set of weights per each output pixels is the result of the chosen + # interpolation method applied to the distances between projected grid + # locations and the pixel-centers in the field of view (distances are + # directed, can be positive or negative) + weights = interp_method(projected_grid[:, None] - field_of_view) + + # we now carefully normalize the weights to sum to 1 per each output pixel + sum_weights = weights.sum(1, keepdims=True) + sum_weights[sum_weights == 0] = 1 + return weights / sum_weights + + def apply_weights(self, input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, fw): + # for this operation we assume the resized dim is the first one. + # so we transpose and will transpose back after multiplying + tmp_input = fw_swapaxes(input, dim, 0, fw) + + # apply padding + tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode) + + # field_of_view is a tensor of order 2: for each output (1d location + # along cur dim)- a list of 1d neighbors locations. + # note that this whole operations is applied to each dim separately, + # this is why it is all in 1d. + # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: + # for each output pixel (this time indicated in all dims), these are the + # values of the neighbors in the 1d field of view. note that we only + # consider neighbors along the current dim, but such set exists for every + # multi-dim location, hence the final tensor order is image_dims+1. + neighbors = tmp_input[field_of_view] + + # weights is an order 2 tensor: for each output location along 1d- a list + # of weights matching the field of view. we augment it with ones, for + # broadcasting, so that when multiplies some tensor the weights affect + # only its first dim. + tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1))) + + # now we simply multiply the weights with the neighbors, and then sum + # along the field of view, to get a single value per out pixel + tmp_output = (neighbors * tmp_weights).sum(1) + + # we transpose back the resized dim to its original position + return fw_swapaxes(tmp_output, 0, dim, fw) + + def apply_convs(self, input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, pad_mode, fw): + # for this operations we assume the resized dim is the last one. + # so we transpose and will transpose back after multiplying + input = fw_swapaxes(input, dim, -1, fw) + + # the stride for all convs is the denominator of the scale factor + stride, num_convs = scale_factor.denominator, scale_factor.numerator + + # prepare an empty tensor for the output + tmp_out_shape = list(input.shape) + tmp_out_shape[-1] = out_sz + tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device) + + # iterate over the conv operations. we have as many as the numerator + # of the scale-factor. for each we need boundaries and a filter. 
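+        #
+        # illustrative example: for scale_factor == Fraction(2, 3) we run
+        # num_convs == 2 one-dimensional convolutions with stride == 3, and
+        # their outputs are interleaved into tmp_output[..., 0::2] and
+        # tmp_output[..., 1::2].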
+ for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)): + # apply padding (we pad last dim, padding can be negative) + pad_dim = input.ndim - 1 + tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim) + + # apply convolution over last dim. store in the output tensor with + # positional strides so that when the loop is comlete conv results are + # interwind + tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride) + + return fw_swapaxes(tmp_output, -1, dim, fw) + + def set_scale_and_out_sz( + self, in_shape, out_shape, scale_factors, by_convs, scale_tolerance, max_numerator, eps, fw + ): + # eventually we must have both scale-factors and out-sizes for all in/out + # dims. however, we support many possible partial arguments + if scale_factors is None and out_shape is None: + raise ValueError("either scale_factors or out_shape should be " "provided") + if out_shape is not None: + # if out_shape has less dims than in_shape, we defaultly resize the + # first dims for numpy and last dims for torch + out_shape = ( + list(out_shape) + list(in_shape[len(out_shape) :]) + if fw is numpy + else list(in_shape[: -len(out_shape)]) + list(out_shape) + ) + if scale_factors is None: + # if no scale given, we calculate it as the out to in ratio + # (not recomended) + scale_factors = [out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)] + if scale_factors is not None: + # by default, if a single number is given as scale, we assume resizing + # two dims (most common are images with 2 spatial dims) + scale_factors = ( + scale_factors if isinstance(scale_factors, (list, tuple)) else [scale_factors, scale_factors] + ) + # if less scale_factors than in_shape dims, we defaultly resize the + # first dims for numpy and last dims for torch + scale_factors = ( + list(scale_factors) + [1] * (len(in_shape) - len(scale_factors)) + if fw is numpy + else [1] * (len(in_shape) - len(scale_factors)) + list(scale_factors) + ) + if out_shape is None: + # when no out_shape given, it is calculated by multiplying the + # scale by the in_shape (not recomended) + out_shape = [ceil(scale_factor * in_sz) for scale_factor, in_sz in zip(scale_factors, in_shape)] + # next part intentionally after out_shape determined for stability + # we fix by_convs to be a list of truth values in case it is not + if not isinstance(by_convs, (list, tuple)): + by_convs = [by_convs] * len(out_shape) + + # next loop fixes the scale for each dim to be either frac or float. + # this is determined by by_convs and by tolerance for scale accuracy. + for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)): + # first we fractionaize + if dim_by_convs: + frac = Fraction(1 / sf).limit_denominator(max_numerator) + frac = Fraction(numerator=frac.denominator, denominator=frac.numerator) + + # if accuracy is within tolerance scale will be frac. if not, then + # it will be float and the by_convs attr will be set false for + # this dim + if scale_tolerance is None: + scale_tolerance = eps + if dim_by_convs and abs(frac - sf) < scale_tolerance: + scale_factors[ind] = frac + else: + scale_factors[ind] = float(sf) + by_convs[ind] = False + + return scale_factors, out_shape, by_convs + + def apply_antialiasing_if_needed(self, interp_method, support_sz, scale_factor, antialiasing): + # antialiasing is "stretching" the field of view according to the scale + # factor (only for downscaling). this is low-pass filtering. 
this + # requires modifying both the interpolation (stretching the 1d + # function and multiplying by the scale-factor) and the window size. + scale_factor = float(scale_factor) + if scale_factor >= 1.0 or not antialiasing: + return interp_method, support_sz + cur_interp_method = lambda arg: scale_factor * interp_method(scale_factor * arg) + cur_support_sz = support_sz / scale_factor + return cur_interp_method, cur_support_sz + + +def fw_ceil(x, fw): + if fw is numpy: + return fw.int_(fw.ceil(x)) + else: + return x.ceil().long() + + +def fw_floor(x, fw): + if fw is numpy: + return fw.int_(fw.floor(x)) + else: + return x.floor().long() + + +def fw_cat(x, fw): + if fw is numpy: + return fw.concatenate(x) + else: + return fw.cat(x) + + +def fw_swapaxes(x, ax_1, ax_2, fw): + if fw is numpy: + return fw.swapaxes(x, ax_1, ax_2) + else: + return x.transpose(ax_1, ax_2) + + +def fw_pad(x, fw, pad_sz, pad_mode, dim=0): + if pad_sz == (0, 0): + return x + if fw is numpy: + pad_vec = [(0, 0)] * x.ndim + pad_vec[dim] = pad_sz + return fw.pad(x, pad_width=pad_vec, mode=pad_mode) + else: + if x.ndim < 3: + x = x[None, None, ...] + + pad_vec = [0] * ((x.ndim - 2) * 2) + pad_vec[0:2] = pad_sz + return fw.nn.functional.pad(x.transpose(dim, -1), pad=pad_vec, mode=pad_mode).transpose(dim, -1) + + +def fw_conv(input, filter, stride): + # we want to apply 1d conv to any nd array. the way to do it is to reshape + # the input to a 4D tensor. first two dims are singeletons, 3rd dim stores + # all the spatial dims that we are not convolving along now. then we can + # apply conv2d with a 1xK filter. This convolves the same way all the other + # dims stored in the 3d dim. like depthwise conv over these. + # TODO: numpy support + reshaped_input = input.reshape(1, 1, -1, input.shape[-1]) + reshaped_output = torch.nn.functional.conv2d(reshaped_input, filter.view(1, 1, 1, -1), stride=(1, stride)) + return reshaped_output.reshape(*input.shape[:-1], -1) + + +def fw_arange(upper_bound, fw, device): + if fw is numpy: + return fw.arange(upper_bound) + else: + return fw.arange(upper_bound, device=device) + + +def fw_empty(shape, fw, device): + if fw is numpy: + return fw.empty(shape) + else: + return fw.empty(size=(*shape,), device=device) diff --git a/guided_diffusion/guided_diffusion/respace.py b/guided_diffusion/guided_diffusion/respace.py new file mode 100644 index 0000000..849f342 --- /dev/null +++ b/guided_diffusion/guided_diffusion/respace.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021 OpenAI +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. 
+ :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model_condn(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model_condn(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) + + def _wrap_model_condn(self, model): + if isinstance(model, _WrappedModel_condn): + return model + return _WrappedModel_condn(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
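+        # Illustrative example: with space_timesteps(1000, "ddim50") the
+        # timestep_map holds the 50 retained original indices; _WrappedModel
+        # translates each spaced index in ts back to its original index (and
+        # applies rescaling if requested) before calling the model.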
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +class _WrappedModel_condn: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, x_in, eps, **kwargs): + # print("jello") + # stop + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, x_in, eps, **kwargs) diff --git a/guided_diffusion/guided_diffusion/script_util.py b/guided_diffusion/guided_diffusion/script_util.py new file mode 100644 index 0000000..8ed4d9a --- /dev/null +++ b/guided_diffusion/guided_diffusion/script_util.py @@ -0,0 +1,450 @@ +# Copyright (c) 2021 OpenAI +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License +import argparse +import inspect + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps +from .unet import EncoderUNetModel, SuperResModel, UNetModel + +NUM_CLASSES = 1000 + + +def diffusion_defaults(): + """ + Defaults for image and classifier training. + """ + return dict( + learn_sigma=False, + diffusion_steps=1000, + noise_schedule="linear", + timestep_respacing="", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + ) + + +def classifier_defaults(): + """ + Defaults for classifier models. + """ + return dict( + image_size=64, + classifier_use_fp16=False, + classifier_width=128, + classifier_depth=2, + classifier_attention_resolutions="32,16,8", # 16 + classifier_use_scale_shift_norm=True, # False + classifier_resblock_updown=True, # False + classifier_pool="attention", + ) + + +def model_and_diffusion_defaults(): + """ + Defaults for image training. 
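+
+    A typical (illustrative) use is to merge these defaults into an argparse
+    parser and feed the parsed values back to create_model_and_diffusion():
+
+        defaults = model_and_diffusion_defaults()
+        parser = argparse.ArgumentParser()
+        add_dict_to_argparser(parser, defaults)
+        args = parser.parse_args()
+        model, diffusion = create_model_and_diffusion(
+            **args_to_dict(args, defaults.keys())
+        )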
+ """ + res = dict( + image_size=64, + num_channels=128, + num_res_blocks=2, + num_heads=4, + num_heads_upsample=-1, + num_head_channels=-1, + attention_resolutions="16,8", + channel_mult="", + dropout=0.0, + class_cond=False, + use_checkpoint=False, + use_scale_shift_norm=True, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, + ) + res.update(diffusion_defaults()) + return res + + +def classifier_and_diffusion_defaults(): + res = classifier_defaults() + res.update(diffusion_defaults()) + return res + + +def create_model_and_diffusion( + image_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + channel_mult, + num_heads, + num_head_channels, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + resblock_updown, + use_fp16, + use_new_attention_order, +): + model = create_model( + image_size, + num_channels, + num_res_blocks, + channel_mult=channel_mult, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + dropout=dropout, + resblock_updown=resblock_updown, + use_fp16=use_fp16, + use_new_attention_order=use_new_attention_order, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def create_model( + image_size, + num_channels, + num_res_blocks, + channel_mult="", + learn_sigma=False, + class_cond=False, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, +): + if channel_mult == "": + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + image_size=image_size, + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + ) + + +def create_classifier_and_diffusion( + image_size, + classifier_use_fp16, + classifier_width, + classifier_depth, + classifier_attention_resolutions, + classifier_use_scale_shift_norm, + 
classifier_resblock_updown, + classifier_pool, + learn_sigma, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, +): + classifier = create_classifier( + image_size, + classifier_use_fp16, + classifier_width, + classifier_depth, + classifier_attention_resolutions, + classifier_use_scale_shift_norm, + classifier_resblock_updown, + classifier_pool, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return classifier, diffusion + + +def create_classifier( + image_size, + classifier_use_fp16, + classifier_width, + classifier_depth, + classifier_attention_resolutions, + classifier_use_scale_shift_norm, + classifier_resblock_updown, + classifier_pool, +): + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + + attention_ds = [] + for res in classifier_attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return EncoderUNetModel( + image_size=image_size, + in_channels=3, + model_channels=classifier_width, + out_channels=1000, + num_res_blocks=classifier_depth, + attention_resolutions=tuple(attention_ds), + channel_mult=channel_mult, + use_fp16=classifier_use_fp16, + num_head_channels=64, + use_scale_shift_norm=classifier_use_scale_shift_norm, + resblock_updown=classifier_resblock_updown, + pool=classifier_pool, + ) + + +def sr_model_and_diffusion_defaults(): + res = model_and_diffusion_defaults() + res["large_size"] = 256 + res["small_size"] = 64 + arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] + for k in res.copy().keys(): + if k not in arg_names: + del res[k] + return res + + +def sr_create_model_and_diffusion( + large_size, + small_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + num_heads, + num_head_channels, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + resblock_updown, + use_fp16, +): + model = sr_create_model( + large_size, + small_size, + num_channels, + num_res_blocks, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + dropout=dropout, + resblock_updown=resblock_updown, + use_fp16=use_fp16, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def sr_create_model( + large_size, + small_size, + num_channels, + num_res_blocks, + learn_sigma, + class_cond, + use_checkpoint, + attention_resolutions, + num_heads, + 
num_head_channels, + num_heads_upsample, + use_scale_shift_norm, + dropout, + resblock_updown, + use_fp16, +): + _ = small_size # hack to prevent unused variable + + if large_size == 512: + channel_mult = (1, 1, 2, 2, 4, 4) + elif large_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif large_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported large size: {large_size}") + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(large_size // int(res)) + + return SuperResModel( + image_size=large_size, + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_fp16=use_fp16, + ) + + +def create_gaussian_diffusion( + *, + steps=1000, + learn_sigma=False, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", +): + betas = gd.get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + ) + + +def add_dict_to_argparser(parser, default_dict): + for k, v in default_dict.items(): + v_type = type(v) + if v is None: + v_type = str + elif isinstance(v, bool): + v_type = str2bool + parser.add_argument(f"--{k}", default=v, type=v_type) + + +def args_to_dict(args, keys): + return {k: getattr(args, k) for k in keys} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") diff --git a/guided_diffusion/guided_diffusion/unet.py b/guided_diffusion/guided_diffusion/unet.py new file mode 100644 index 0000000..4bcc1fe --- /dev/null +++ b/guided_diffusion/guided_diffusion/unet.py @@ -0,0 +1,866 @@ +# Copyright (c) 2021 OpenAI +# SPDX-License-Identifier: MIT + +import math + +# Code adapted from https://github.com/openai/guided-diffusion/tree/main/guided_diffusion -- MIT License +from abc import abstractmethod + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .fp16_util import convert_module_to_f16, convert_module_to_f32 +from .nn import avg_pool_nd, checkpoint, conv_nd, linear, normalization, timestep_embedding, zero_module + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: 
https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. 
+ :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
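+
+    Input and output are [N x C x ...spatial] tensors of the same shape;
+    attention is computed over the flattened spatial positions, e.g. an
+    [N, C, H, W] feature map is treated as a sequence of length H*W.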
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
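+
+        Shape example (illustrative): with n_heads == 4 and qkv of shape
+        [2, 3 * 4 * 16, 64], each head gets ch == 16 channels and the result
+        has shape [2, 4 * 16, 64].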
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
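+
+    Example construction (illustrative; the values are placeholders and the
+    attention_resolutions entries are downsample rates, as described above):
+
+        model = UNetModel(
+            image_size=64,
+            in_channels=3,
+            model_channels=128,
+            out_channels=3,
+            num_res_blocks=2,
+            attention_resolutions=(4, 8),
+        )
+        out = model(x, timesteps)  # x: [N, 3, 64, 64], timesteps: [N]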
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + 
dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + return self.out(h) + + +class SuperResModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, image_size, in_channels, *args, **kwargs): + super().__init__(image_size, in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + x = th.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + + For usage, see UNet. 
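+
+    A minimal sketch (the settings below are illustrative assumptions): with
+    pool="adaptive", an [N x C x H x W] batch and [N] timesteps are mapped to
+    an [N x out_channels] feature vector, e.g.
+
+        enc = EncoderUNetModel(
+            image_size=64, in_channels=3, model_channels=128, out_channels=1000,
+            num_res_blocks=2, attention_resolutions=(4, 8), pool="adaptive",
+        )
+        feats = enc(x, t)  # feats: [N, 1000]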
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == "spatial": + self.out = 
nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/input_example/hairs/00000_hair.png b/input_example/hairs/00000_hair.png new file mode 100644 index 0000000..b65f478 Binary files /dev/null and b/input_example/hairs/00000_hair.png differ diff --git a/input_example/hairs/00001_hair.png b/input_example/hairs/00001_hair.png new file mode 100644 index 0000000..ac1da36 Binary files /dev/null and b/input_example/hairs/00001_hair.png differ diff --git a/input_example/hairs/00002_hair.png b/input_example/hairs/00002_hair.png new file mode 100644 index 0000000..00d0a0a Binary files /dev/null and b/input_example/hairs/00002_hair.png differ diff --git a/input_example/hairs/00003_hair.png b/input_example/hairs/00003_hair.png new file mode 100644 index 0000000..1a22829 Binary files /dev/null and b/input_example/hairs/00003_hair.png differ diff --git a/input_example/hairs/00004_hair.png b/input_example/hairs/00004_hair.png new file mode 100644 index 0000000..4f3f0be Binary files /dev/null and b/input_example/hairs/00004_hair.png differ diff --git a/input_example/hairs/00005_hair.png b/input_example/hairs/00005_hair.png new file mode 100644 index 0000000..d408097 Binary files /dev/null and b/input_example/hairs/00005_hair.png differ diff --git a/input_example/hairs/00006_hair.png b/input_example/hairs/00006_hair.png new file mode 100644 index 0000000..45366f3 Binary files /dev/null and b/input_example/hairs/00006_hair.png differ diff --git a/input_example/hairs/00007_hair.png b/input_example/hairs/00007_hair.png new file mode 100644 index 0000000..f7d6eba Binary files /dev/null and b/input_example/hairs/00007_hair.png differ diff --git a/input_example/hairs/00008_hair.png b/input_example/hairs/00008_hair.png new file mode 100644 index 0000000..0b0cfa4 Binary files /dev/null and b/input_example/hairs/00008_hair.png differ diff --git a/input_example/hairs/00009_hair.png b/input_example/hairs/00009_hair.png new file mode 100644 index 0000000..00f283a Binary files /dev/null and b/input_example/hairs/00009_hair.png differ diff --git a/input_example/hairs/00010_hair.png b/input_example/hairs/00010_hair.png new file mode 100644 index 
0000000..fcea06b Binary files /dev/null and b/input_example/hairs/00010_hair.png differ diff --git a/input_example/hairs/00011_hair.png b/input_example/hairs/00011_hair.png new file mode 100644 index 0000000..5755923 Binary files /dev/null and b/input_example/hairs/00011_hair.png differ diff --git a/input_example/hairs/00012_hair.png b/input_example/hairs/00012_hair.png new file mode 100644 index 0000000..ce77d2c Binary files /dev/null and b/input_example/hairs/00012_hair.png differ diff --git a/input_example/hairs/00013_hair.png b/input_example/hairs/00013_hair.png new file mode 100644 index 0000000..ec72d58 Binary files /dev/null and b/input_example/hairs/00013_hair.png differ diff --git a/input_example/hairs/00014_hair.png b/input_example/hairs/00014_hair.png new file mode 100644 index 0000000..319da2b Binary files /dev/null and b/input_example/hairs/00014_hair.png differ diff --git a/input_example/hairs/00015_hair.png b/input_example/hairs/00015_hair.png new file mode 100644 index 0000000..5a57811 Binary files /dev/null and b/input_example/hairs/00015_hair.png differ diff --git a/input_example/hairs/00016_hair.png b/input_example/hairs/00016_hair.png new file mode 100644 index 0000000..f7b4b7b Binary files /dev/null and b/input_example/hairs/00016_hair.png differ diff --git a/input_example/hairs/00017_hair.png b/input_example/hairs/00017_hair.png new file mode 100644 index 0000000..bf9db07 Binary files /dev/null and b/input_example/hairs/00017_hair.png differ diff --git a/input_example/hairs/00018_hair.png b/input_example/hairs/00018_hair.png new file mode 100644 index 0000000..cb8ad96 Binary files /dev/null and b/input_example/hairs/00018_hair.png differ diff --git a/input_example/hairs/00019_hair.png b/input_example/hairs/00019_hair.png new file mode 100644 index 0000000..2f4bf04 Binary files /dev/null and b/input_example/hairs/00019_hair.png differ diff --git a/input_example/hairs/00020_hair.png b/input_example/hairs/00020_hair.png new file mode 100644 index 0000000..66dc081 Binary files /dev/null and b/input_example/hairs/00020_hair.png differ diff --git a/input_example/hairs/00021_hair.png b/input_example/hairs/00021_hair.png new file mode 100644 index 0000000..9679488 Binary files /dev/null and b/input_example/hairs/00021_hair.png differ diff --git a/input_example/hairs/00022_hair.png b/input_example/hairs/00022_hair.png new file mode 100644 index 0000000..4a7cd82 Binary files /dev/null and b/input_example/hairs/00022_hair.png differ diff --git a/input_example/hairs/00023_hair.png b/input_example/hairs/00023_hair.png new file mode 100644 index 0000000..dd65f08 Binary files /dev/null and b/input_example/hairs/00023_hair.png differ diff --git a/input_example/hairs/00024_hair.png b/input_example/hairs/00024_hair.png new file mode 100644 index 0000000..e03ac97 Binary files /dev/null and b/input_example/hairs/00024_hair.png differ diff --git a/input_example/hairs/00025_hair.png b/input_example/hairs/00025_hair.png new file mode 100644 index 0000000..4d9a562 Binary files /dev/null and b/input_example/hairs/00025_hair.png differ diff --git a/input_example/hairs/00026_hair.png b/input_example/hairs/00026_hair.png new file mode 100644 index 0000000..8e8e019 Binary files /dev/null and b/input_example/hairs/00026_hair.png differ diff --git a/input_example/hairs/00027_hair.png b/input_example/hairs/00027_hair.png new file mode 100644 index 0000000..f0c1851 Binary files /dev/null and b/input_example/hairs/00027_hair.png differ diff --git a/input_example/hairs/00028_hair.png 
b/input_example/hairs/00028_hair.png new file mode 100644 index 0000000..2468a4c Binary files /dev/null and b/input_example/hairs/00028_hair.png differ diff --git a/input_example/hairs/00029_hair.png b/input_example/hairs/00029_hair.png new file mode 100644 index 0000000..cbcd215 Binary files /dev/null and b/input_example/hairs/00029_hair.png differ diff --git a/input_example/hairs/00030_hair.png b/input_example/hairs/00030_hair.png new file mode 100644 index 0000000..3b09936 Binary files /dev/null and b/input_example/hairs/00030_hair.png differ diff --git a/input_example/hairs/00031_hair.png b/input_example/hairs/00031_hair.png new file mode 100644 index 0000000..c90ec62 Binary files /dev/null and b/input_example/hairs/00031_hair.png differ diff --git a/input_example/hairs/00032_hair.png b/input_example/hairs/00032_hair.png new file mode 100644 index 0000000..ff18adc Binary files /dev/null and b/input_example/hairs/00032_hair.png differ diff --git a/input_example/hairs/00033_hair.png b/input_example/hairs/00033_hair.png new file mode 100644 index 0000000..21a0e49 Binary files /dev/null and b/input_example/hairs/00033_hair.png differ diff --git a/input_example/hairs/00034_hair.png b/input_example/hairs/00034_hair.png new file mode 100644 index 0000000..dcf6f60 Binary files /dev/null and b/input_example/hairs/00034_hair.png differ diff --git a/input_example/hairs/00035_hair.png b/input_example/hairs/00035_hair.png new file mode 100644 index 0000000..4fa9335 Binary files /dev/null and b/input_example/hairs/00035_hair.png differ diff --git a/input_example/hairs/00036_hair.png b/input_example/hairs/00036_hair.png new file mode 100644 index 0000000..f0726ed Binary files /dev/null and b/input_example/hairs/00036_hair.png differ diff --git a/input_example/hairs/00037_hair.png b/input_example/hairs/00037_hair.png new file mode 100644 index 0000000..d6d0222 Binary files /dev/null and b/input_example/hairs/00037_hair.png differ diff --git a/input_example/hairs/00038_hair.png b/input_example/hairs/00038_hair.png new file mode 100644 index 0000000..c5624fe Binary files /dev/null and b/input_example/hairs/00038_hair.png differ diff --git a/input_example/hairs/00039_hair.png b/input_example/hairs/00039_hair.png new file mode 100644 index 0000000..75af7cd Binary files /dev/null and b/input_example/hairs/00039_hair.png differ diff --git a/input_example/hairs/00040_hair.png b/input_example/hairs/00040_hair.png new file mode 100644 index 0000000..238c6e0 Binary files /dev/null and b/input_example/hairs/00040_hair.png differ diff --git a/input_example/hairs/00041_hair.png b/input_example/hairs/00041_hair.png new file mode 100644 index 0000000..0d9fa4f Binary files /dev/null and b/input_example/hairs/00041_hair.png differ diff --git a/input_example/hairs/00042_hair.png b/input_example/hairs/00042_hair.png new file mode 100644 index 0000000..2198b23 Binary files /dev/null and b/input_example/hairs/00042_hair.png differ diff --git a/input_example/hairs/00043_hair.png b/input_example/hairs/00043_hair.png new file mode 100644 index 0000000..a2ae13c Binary files /dev/null and b/input_example/hairs/00043_hair.png differ diff --git a/input_example/hairs/00044_hair.png b/input_example/hairs/00044_hair.png new file mode 100644 index 0000000..1f8fca2 Binary files /dev/null and b/input_example/hairs/00044_hair.png differ diff --git a/input_example/hairs/00045_hair.png b/input_example/hairs/00045_hair.png new file mode 100644 index 0000000..17ea1a6 Binary files /dev/null and b/input_example/hairs/00045_hair.png differ 
diff --git a/input_example/hairs/00046_hair.png b/input_example/hairs/00046_hair.png new file mode 100644 index 0000000..8c28e2a Binary files /dev/null and b/input_example/hairs/00046_hair.png differ diff --git a/input_example/hairs/00047_hair.png b/input_example/hairs/00047_hair.png new file mode 100644 index 0000000..ab1e32b Binary files /dev/null and b/input_example/hairs/00047_hair.png differ diff --git a/input_example/hairs/00048_hair.png b/input_example/hairs/00048_hair.png new file mode 100644 index 0000000..7a19391 Binary files /dev/null and b/input_example/hairs/00048_hair.png differ diff --git a/input_example/hairs/00049_hair.png b/input_example/hairs/00049_hair.png new file mode 100644 index 0000000..8b7b883 Binary files /dev/null and b/input_example/hairs/00049_hair.png differ diff --git a/input_example/hairs/00050_hair.png b/input_example/hairs/00050_hair.png new file mode 100644 index 0000000..869c1fe Binary files /dev/null and b/input_example/hairs/00050_hair.png differ diff --git a/input_example/hairs/00051_hair.png b/input_example/hairs/00051_hair.png new file mode 100644 index 0000000..859bda6 Binary files /dev/null and b/input_example/hairs/00051_hair.png differ diff --git a/input_example/hairs/00052_hair.png b/input_example/hairs/00052_hair.png new file mode 100644 index 0000000..d41836e Binary files /dev/null and b/input_example/hairs/00052_hair.png differ diff --git a/input_example/hairs/00053_hair.png b/input_example/hairs/00053_hair.png new file mode 100644 index 0000000..a267c85 Binary files /dev/null and b/input_example/hairs/00053_hair.png differ diff --git a/input_example/hairs/00054_hair.png b/input_example/hairs/00054_hair.png new file mode 100644 index 0000000..4e0609e Binary files /dev/null and b/input_example/hairs/00054_hair.png differ diff --git a/input_example/hairs/00055_hair.png b/input_example/hairs/00055_hair.png new file mode 100644 index 0000000..5ca4b83 Binary files /dev/null and b/input_example/hairs/00055_hair.png differ diff --git a/input_example/hairs/00056_hair.png b/input_example/hairs/00056_hair.png new file mode 100644 index 0000000..3d01c81 Binary files /dev/null and b/input_example/hairs/00056_hair.png differ diff --git a/input_example/hairs/00057_hair.png b/input_example/hairs/00057_hair.png new file mode 100644 index 0000000..0827f95 Binary files /dev/null and b/input_example/hairs/00057_hair.png differ diff --git a/input_example/hairs/00058_hair.png b/input_example/hairs/00058_hair.png new file mode 100644 index 0000000..becab20 Binary files /dev/null and b/input_example/hairs/00058_hair.png differ diff --git a/input_example/hairs/00059_hair.png b/input_example/hairs/00059_hair.png new file mode 100644 index 0000000..242c916 Binary files /dev/null and b/input_example/hairs/00059_hair.png differ diff --git a/input_example/hairs/00060_hair.png b/input_example/hairs/00060_hair.png new file mode 100644 index 0000000..e8a3b13 Binary files /dev/null and b/input_example/hairs/00060_hair.png differ diff --git a/input_example/hairs/00061_hair.png b/input_example/hairs/00061_hair.png new file mode 100644 index 0000000..1385316 Binary files /dev/null and b/input_example/hairs/00061_hair.png differ diff --git a/input_example/hairs/00062_hair.png b/input_example/hairs/00062_hair.png new file mode 100644 index 0000000..2138614 Binary files /dev/null and b/input_example/hairs/00062_hair.png differ diff --git a/input_example/hairs/00063_hair.png b/input_example/hairs/00063_hair.png new file mode 100644 index 0000000..6707c1f Binary files /dev/null 
and b/input_example/hairs/00063_hair.png differ diff --git a/input_example/hairs/00065_hair.png b/input_example/hairs/00065_hair.png new file mode 100644 index 0000000..2267ee8 Binary files /dev/null and b/input_example/hairs/00065_hair.png differ diff --git a/input_example/hairs/00066_hair.png b/input_example/hairs/00066_hair.png new file mode 100644 index 0000000..554f004 Binary files /dev/null and b/input_example/hairs/00066_hair.png differ diff --git a/input_example/hairs/00067_hair.png b/input_example/hairs/00067_hair.png new file mode 100644 index 0000000..d817fcb Binary files /dev/null and b/input_example/hairs/00067_hair.png differ diff --git a/input_example/hairs/00068_hair.png b/input_example/hairs/00068_hair.png new file mode 100644 index 0000000..5a55d8f Binary files /dev/null and b/input_example/hairs/00068_hair.png differ diff --git a/input_example/hairs/00069_hair.png b/input_example/hairs/00069_hair.png new file mode 100644 index 0000000..feb4ba0 Binary files /dev/null and b/input_example/hairs/00069_hair.png differ diff --git a/input_example/hairs/00070_hair.png b/input_example/hairs/00070_hair.png new file mode 100644 index 0000000..ba1fad9 Binary files /dev/null and b/input_example/hairs/00070_hair.png differ diff --git a/input_example/hairs/00071_hair.png b/input_example/hairs/00071_hair.png new file mode 100644 index 0000000..1a88a6c Binary files /dev/null and b/input_example/hairs/00071_hair.png differ diff --git a/input_example/hairs/00072_hair.png b/input_example/hairs/00072_hair.png new file mode 100644 index 0000000..f0891cc Binary files /dev/null and b/input_example/hairs/00072_hair.png differ diff --git a/input_example/hairs/00073_hair.png b/input_example/hairs/00073_hair.png new file mode 100644 index 0000000..04a4a5c Binary files /dev/null and b/input_example/hairs/00073_hair.png differ diff --git a/input_example/hairs/00074_hair.png b/input_example/hairs/00074_hair.png new file mode 100644 index 0000000..ee8d7fc Binary files /dev/null and b/input_example/hairs/00074_hair.png differ diff --git a/input_example/hairs/00075_hair.png b/input_example/hairs/00075_hair.png new file mode 100644 index 0000000..a379be8 Binary files /dev/null and b/input_example/hairs/00075_hair.png differ diff --git a/input_example/hairs/00076_hair.png b/input_example/hairs/00076_hair.png new file mode 100644 index 0000000..a33a432 Binary files /dev/null and b/input_example/hairs/00076_hair.png differ diff --git a/input_example/hairs/00077_hair.png b/input_example/hairs/00077_hair.png new file mode 100644 index 0000000..634f34c Binary files /dev/null and b/input_example/hairs/00077_hair.png differ diff --git a/input_example/hairs/00078_hair.png b/input_example/hairs/00078_hair.png new file mode 100644 index 0000000..c6da9f1 Binary files /dev/null and b/input_example/hairs/00078_hair.png differ diff --git a/input_example/hairs/00079_hair.png b/input_example/hairs/00079_hair.png new file mode 100644 index 0000000..7402e2e Binary files /dev/null and b/input_example/hairs/00079_hair.png differ diff --git a/input_example/hairs/00080_hair.png b/input_example/hairs/00080_hair.png new file mode 100644 index 0000000..670a833 Binary files /dev/null and b/input_example/hairs/00080_hair.png differ diff --git a/input_example/hairs/00081_hair.png b/input_example/hairs/00081_hair.png new file mode 100644 index 0000000..d2122e9 Binary files /dev/null and b/input_example/hairs/00081_hair.png differ diff --git a/input_example/hairs/00082_hair.png b/input_example/hairs/00082_hair.png new file mode 100644 
index 0000000..e2dd70d Binary files /dev/null and b/input_example/hairs/00082_hair.png differ diff --git a/input_example/hairs/00083_hair.png b/input_example/hairs/00083_hair.png new file mode 100644 index 0000000..1064816 Binary files /dev/null and b/input_example/hairs/00083_hair.png differ diff --git a/input_example/hairs/00084_hair.png b/input_example/hairs/00084_hair.png new file mode 100644 index 0000000..a65e55b Binary files /dev/null and b/input_example/hairs/00084_hair.png differ diff --git a/input_example/hairs/00085_hair.png b/input_example/hairs/00085_hair.png new file mode 100644 index 0000000..b135abd Binary files /dev/null and b/input_example/hairs/00085_hair.png differ diff --git a/input_example/hairs/00086_hair.png b/input_example/hairs/00086_hair.png new file mode 100644 index 0000000..a2657cc Binary files /dev/null and b/input_example/hairs/00086_hair.png differ diff --git a/input_example/hairs/00087_hair.png b/input_example/hairs/00087_hair.png new file mode 100644 index 0000000..c86ffbb Binary files /dev/null and b/input_example/hairs/00087_hair.png differ diff --git a/input_example/hairs/00088_hair.png b/input_example/hairs/00088_hair.png new file mode 100644 index 0000000..7f57a05 Binary files /dev/null and b/input_example/hairs/00088_hair.png differ diff --git a/input_example/hairs/00089_hair.png b/input_example/hairs/00089_hair.png new file mode 100644 index 0000000..ffd4df1 Binary files /dev/null and b/input_example/hairs/00089_hair.png differ diff --git a/input_example/hairs/00090_hair.png b/input_example/hairs/00090_hair.png new file mode 100644 index 0000000..67e65a9 Binary files /dev/null and b/input_example/hairs/00090_hair.png differ diff --git a/input_example/hairs/00091_hair.png b/input_example/hairs/00091_hair.png new file mode 100644 index 0000000..c6a06fa Binary files /dev/null and b/input_example/hairs/00091_hair.png differ diff --git a/input_example/hairs/00092_hair.png b/input_example/hairs/00092_hair.png new file mode 100644 index 0000000..58c9d75 Binary files /dev/null and b/input_example/hairs/00092_hair.png differ diff --git a/input_example/hairs/00093_hair.png b/input_example/hairs/00093_hair.png new file mode 100644 index 0000000..a0a5ad7 Binary files /dev/null and b/input_example/hairs/00093_hair.png differ diff --git a/input_example/hairs/00095_hair.png b/input_example/hairs/00095_hair.png new file mode 100644 index 0000000..b0b58bf Binary files /dev/null and b/input_example/hairs/00095_hair.png differ diff --git a/input_example/hairs/00096_hair.png b/input_example/hairs/00096_hair.png new file mode 100644 index 0000000..b11163d Binary files /dev/null and b/input_example/hairs/00096_hair.png differ diff --git a/input_example/hairs/00097_hair.png b/input_example/hairs/00097_hair.png new file mode 100644 index 0000000..fe1e8bd Binary files /dev/null and b/input_example/hairs/00097_hair.png differ diff --git a/input_example/hairs/00098_hair.png b/input_example/hairs/00098_hair.png new file mode 100644 index 0000000..bb05084 Binary files /dev/null and b/input_example/hairs/00098_hair.png differ diff --git a/input_example/hairs/00099_hair.png b/input_example/hairs/00099_hair.png new file mode 100644 index 0000000..decf795 Binary files /dev/null and b/input_example/hairs/00099_hair.png differ diff --git a/input_example/hairs/00100_hair.png b/input_example/hairs/00100_hair.png new file mode 100644 index 0000000..547dfe0 Binary files /dev/null and b/input_example/hairs/00100_hair.png differ diff --git a/input_example/hairs/00101_hair.png 
b/input_example/hairs/00101_hair.png new file mode 100644 index 0000000..8267677 Binary files /dev/null and b/input_example/hairs/00101_hair.png differ diff --git a/input_example/hairs/00102_hair.png b/input_example/hairs/00102_hair.png new file mode 100644 index 0000000..4bcb1be Binary files /dev/null and b/input_example/hairs/00102_hair.png differ diff --git a/input_example/hairs/00103_hair.png b/input_example/hairs/00103_hair.png new file mode 100644 index 0000000..a504cb9 Binary files /dev/null and b/input_example/hairs/00103_hair.png differ diff --git a/input_example/hairs/00104_hair.png b/input_example/hairs/00104_hair.png new file mode 100644 index 0000000..486c4b0 Binary files /dev/null and b/input_example/hairs/00104_hair.png differ diff --git a/input_example/hairs/00105_hair.png b/input_example/hairs/00105_hair.png new file mode 100644 index 0000000..a5dee77 Binary files /dev/null and b/input_example/hairs/00105_hair.png differ diff --git a/input_example/hairs/00106_hair.png b/input_example/hairs/00106_hair.png new file mode 100644 index 0000000..f87cb18 Binary files /dev/null and b/input_example/hairs/00106_hair.png differ diff --git a/input_example/hairs/00107_hair.png b/input_example/hairs/00107_hair.png new file mode 100644 index 0000000..864d961 Binary files /dev/null and b/input_example/hairs/00107_hair.png differ diff --git a/input_example/hairs/00108_hair.png b/input_example/hairs/00108_hair.png new file mode 100644 index 0000000..2389947 Binary files /dev/null and b/input_example/hairs/00108_hair.png differ diff --git a/input_example/hairs/00109_hair.png b/input_example/hairs/00109_hair.png new file mode 100644 index 0000000..bbc2085 Binary files /dev/null and b/input_example/hairs/00109_hair.png differ diff --git a/input_example/hairs/00110_hair.png b/input_example/hairs/00110_hair.png new file mode 100644 index 0000000..7d5f99e Binary files /dev/null and b/input_example/hairs/00110_hair.png differ diff --git a/input_example/hairs/00111_hair.png b/input_example/hairs/00111_hair.png new file mode 100644 index 0000000..1ebf4da Binary files /dev/null and b/input_example/hairs/00111_hair.png differ diff --git a/input_example/hairs/00112_hair.png b/input_example/hairs/00112_hair.png new file mode 100644 index 0000000..612540d Binary files /dev/null and b/input_example/hairs/00112_hair.png differ diff --git a/input_example/hairs/00113_hair.png b/input_example/hairs/00113_hair.png new file mode 100644 index 0000000..33ab969 Binary files /dev/null and b/input_example/hairs/00113_hair.png differ diff --git a/input_example/hairs/00114_hair.png b/input_example/hairs/00114_hair.png new file mode 100644 index 0000000..4e893e2 Binary files /dev/null and b/input_example/hairs/00114_hair.png differ diff --git a/input_example/hairs/00115_hair.png b/input_example/hairs/00115_hair.png new file mode 100644 index 0000000..a364e6f Binary files /dev/null and b/input_example/hairs/00115_hair.png differ diff --git a/input_example/hairs/00116_hair.png b/input_example/hairs/00116_hair.png new file mode 100644 index 0000000..29ec6c6 Binary files /dev/null and b/input_example/hairs/00116_hair.png differ diff --git a/input_example/hairs/00117_hair.png b/input_example/hairs/00117_hair.png new file mode 100644 index 0000000..a50bf49 Binary files /dev/null and b/input_example/hairs/00117_hair.png differ diff --git a/input_example/hairs/00118_hair.png b/input_example/hairs/00118_hair.png new file mode 100644 index 0000000..78fbbed Binary files /dev/null and b/input_example/hairs/00118_hair.png differ 
diff --git a/input_example/hairs/00119_hair.png b/input_example/hairs/00119_hair.png new file mode 100644 index 0000000..ff13c84 Binary files /dev/null and b/input_example/hairs/00119_hair.png differ diff --git a/input_example/hairs/00120_hair.png b/input_example/hairs/00120_hair.png new file mode 100644 index 0000000..21766a3 Binary files /dev/null and b/input_example/hairs/00120_hair.png differ diff --git a/input_example/hairs/00121_hair.png b/input_example/hairs/00121_hair.png new file mode 100644 index 0000000..5ca72fb Binary files /dev/null and b/input_example/hairs/00121_hair.png differ diff --git a/input_example/hairs/00122_hair.png b/input_example/hairs/00122_hair.png new file mode 100644 index 0000000..0d544e4 Binary files /dev/null and b/input_example/hairs/00122_hair.png differ diff --git a/input_example/hairs/00123_hair.png b/input_example/hairs/00123_hair.png new file mode 100644 index 0000000..00b4efb Binary files /dev/null and b/input_example/hairs/00123_hair.png differ diff --git a/input_example/hairs/00124_hair.png b/input_example/hairs/00124_hair.png new file mode 100644 index 0000000..f4df126 Binary files /dev/null and b/input_example/hairs/00124_hair.png differ diff --git a/input_example/hairs/00125_hair.png b/input_example/hairs/00125_hair.png new file mode 100644 index 0000000..cb809ae Binary files /dev/null and b/input_example/hairs/00125_hair.png differ diff --git a/input_example/hairs/00126_hair.png b/input_example/hairs/00126_hair.png new file mode 100644 index 0000000..62627b2 Binary files /dev/null and b/input_example/hairs/00126_hair.png differ diff --git a/input_example/hairs/00127_hair.png b/input_example/hairs/00127_hair.png new file mode 100644 index 0000000..11c2d39 Binary files /dev/null and b/input_example/hairs/00127_hair.png differ diff --git a/input_example/hairs/00128_hair.png b/input_example/hairs/00128_hair.png new file mode 100644 index 0000000..eef161c Binary files /dev/null and b/input_example/hairs/00128_hair.png differ diff --git a/input_example/hairs/00129_hair.png b/input_example/hairs/00129_hair.png new file mode 100644 index 0000000..1e660f5 Binary files /dev/null and b/input_example/hairs/00129_hair.png differ diff --git a/input_example/hairs/00130_hair.png b/input_example/hairs/00130_hair.png new file mode 100644 index 0000000..66d2197 Binary files /dev/null and b/input_example/hairs/00130_hair.png differ diff --git a/input_example/hairs/00131_hair.png b/input_example/hairs/00131_hair.png new file mode 100644 index 0000000..fbedf0c Binary files /dev/null and b/input_example/hairs/00131_hair.png differ diff --git a/input_example/hairs/00132_hair.png b/input_example/hairs/00132_hair.png new file mode 100644 index 0000000..c064b62 Binary files /dev/null and b/input_example/hairs/00132_hair.png differ diff --git a/input_example/hairs/00133_hair.png b/input_example/hairs/00133_hair.png new file mode 100644 index 0000000..3ae72aa Binary files /dev/null and b/input_example/hairs/00133_hair.png differ diff --git a/input_example/hairs/00134_hair.png b/input_example/hairs/00134_hair.png new file mode 100644 index 0000000..1a7d02b Binary files /dev/null and b/input_example/hairs/00134_hair.png differ diff --git a/input_example/hairs/00135_hair.png b/input_example/hairs/00135_hair.png new file mode 100644 index 0000000..e57ee87 Binary files /dev/null and b/input_example/hairs/00135_hair.png differ diff --git a/input_example/hairs/00136_hair.png b/input_example/hairs/00136_hair.png new file mode 100644 index 0000000..2bc0f1f Binary files /dev/null 
and b/input_example/hairs/00136_hair.png differ diff --git a/input_example/hairs/00137_hair.png b/input_example/hairs/00137_hair.png new file mode 100644 index 0000000..f778de4 Binary files /dev/null and b/input_example/hairs/00137_hair.png differ diff --git a/input_example/hairs/00139_hair.png b/input_example/hairs/00139_hair.png new file mode 100644 index 0000000..75b02d0 Binary files /dev/null and b/input_example/hairs/00139_hair.png differ diff --git a/input_example/hairs/00140_hair.png b/input_example/hairs/00140_hair.png new file mode 100644 index 0000000..da77e9c Binary files /dev/null and b/input_example/hairs/00140_hair.png differ diff --git a/input_example/hairs/00141_hair.png b/input_example/hairs/00141_hair.png new file mode 100644 index 0000000..d24f945 Binary files /dev/null and b/input_example/hairs/00141_hair.png differ diff --git a/input_example/hairs/00142_hair.png b/input_example/hairs/00142_hair.png new file mode 100644 index 0000000..6d0515b Binary files /dev/null and b/input_example/hairs/00142_hair.png differ diff --git a/input_example/hairs/00143_hair.png b/input_example/hairs/00143_hair.png new file mode 100644 index 0000000..c5aaa63 Binary files /dev/null and b/input_example/hairs/00143_hair.png differ diff --git a/input_example/hairs/00144_hair.png b/input_example/hairs/00144_hair.png new file mode 100644 index 0000000..f32781c Binary files /dev/null and b/input_example/hairs/00144_hair.png differ diff --git a/input_example/hairs/00145_hair.png b/input_example/hairs/00145_hair.png new file mode 100644 index 0000000..479952e Binary files /dev/null and b/input_example/hairs/00145_hair.png differ diff --git a/input_example/hairs/00146_hair.png b/input_example/hairs/00146_hair.png new file mode 100644 index 0000000..21f47ba Binary files /dev/null and b/input_example/hairs/00146_hair.png differ diff --git a/input_example/hairs/00147_hair.png b/input_example/hairs/00147_hair.png new file mode 100644 index 0000000..4ef5d66 Binary files /dev/null and b/input_example/hairs/00147_hair.png differ diff --git a/input_example/hairs/00148_hair.png b/input_example/hairs/00148_hair.png new file mode 100644 index 0000000..f2795d4 Binary files /dev/null and b/input_example/hairs/00148_hair.png differ diff --git a/input_example/hairs/00149_hair.png b/input_example/hairs/00149_hair.png new file mode 100644 index 0000000..ccc44e3 Binary files /dev/null and b/input_example/hairs/00149_hair.png differ diff --git a/input_example/hairs/00150_hair.png b/input_example/hairs/00150_hair.png new file mode 100644 index 0000000..aa0b20e Binary files /dev/null and b/input_example/hairs/00150_hair.png differ diff --git a/input_example/hairs/00151_hair.png b/input_example/hairs/00151_hair.png new file mode 100644 index 0000000..a723b53 Binary files /dev/null and b/input_example/hairs/00151_hair.png differ diff --git a/input_example/hairs/00152_hair.png b/input_example/hairs/00152_hair.png new file mode 100644 index 0000000..b4b8927 Binary files /dev/null and b/input_example/hairs/00152_hair.png differ diff --git a/input_example/hairs/00153_hair.png b/input_example/hairs/00153_hair.png new file mode 100644 index 0000000..a798e26 Binary files /dev/null and b/input_example/hairs/00153_hair.png differ diff --git a/input_example/hairs/00154_hair.png b/input_example/hairs/00154_hair.png new file mode 100644 index 0000000..0edbbbf Binary files /dev/null and b/input_example/hairs/00154_hair.png differ diff --git a/input_example/hairs/00155_hair.png b/input_example/hairs/00155_hair.png new file mode 100644 
index 0000000..ab470e7 Binary files /dev/null and b/input_example/hairs/00155_hair.png differ diff --git a/input_example/hairs/00156_hair.png b/input_example/hairs/00156_hair.png new file mode 100644 index 0000000..b73fdf1 Binary files /dev/null and b/input_example/hairs/00156_hair.png differ diff --git a/input_example/hairs/00157_hair.png b/input_example/hairs/00157_hair.png new file mode 100644 index 0000000..90a73f5 Binary files /dev/null and b/input_example/hairs/00157_hair.png differ diff --git a/input_example/hairs/00158_hair.png b/input_example/hairs/00158_hair.png new file mode 100644 index 0000000..8af0758 Binary files /dev/null and b/input_example/hairs/00158_hair.png differ diff --git a/input_example/hairs/00159_hair.png b/input_example/hairs/00159_hair.png new file mode 100644 index 0000000..e400796 Binary files /dev/null and b/input_example/hairs/00159_hair.png differ diff --git a/input_example/hairs/00160_hair.png b/input_example/hairs/00160_hair.png new file mode 100644 index 0000000..fd5793e Binary files /dev/null and b/input_example/hairs/00160_hair.png differ diff --git a/input_example/hairs/00161_hair.png b/input_example/hairs/00161_hair.png new file mode 100644 index 0000000..1737bcb Binary files /dev/null and b/input_example/hairs/00161_hair.png differ diff --git a/input_example/hairs/00162_hair.png b/input_example/hairs/00162_hair.png new file mode 100644 index 0000000..22ab1e9 Binary files /dev/null and b/input_example/hairs/00162_hair.png differ diff --git a/input_example/hairs/00163_hair.png b/input_example/hairs/00163_hair.png new file mode 100644 index 0000000..e621403 Binary files /dev/null and b/input_example/hairs/00163_hair.png differ diff --git a/input_example/hairs/00164_hair.png b/input_example/hairs/00164_hair.png new file mode 100644 index 0000000..9e673af Binary files /dev/null and b/input_example/hairs/00164_hair.png differ diff --git a/input_example/hairs/00165_hair.png b/input_example/hairs/00165_hair.png new file mode 100644 index 0000000..820b02b Binary files /dev/null and b/input_example/hairs/00165_hair.png differ diff --git a/input_example/hairs/00166_hair.png b/input_example/hairs/00166_hair.png new file mode 100644 index 0000000..12e2f9e Binary files /dev/null and b/input_example/hairs/00166_hair.png differ diff --git a/input_example/hairs/00167_hair.png b/input_example/hairs/00167_hair.png new file mode 100644 index 0000000..02621bc Binary files /dev/null and b/input_example/hairs/00167_hair.png differ diff --git a/input_example/hairs/00168_hair.png b/input_example/hairs/00168_hair.png new file mode 100644 index 0000000..3e4f4a6 Binary files /dev/null and b/input_example/hairs/00168_hair.png differ diff --git a/input_example/hairs/00169_hair.png b/input_example/hairs/00169_hair.png new file mode 100644 index 0000000..fb73596 Binary files /dev/null and b/input_example/hairs/00169_hair.png differ diff --git a/input_example/hairs/00170_hair.png b/input_example/hairs/00170_hair.png new file mode 100644 index 0000000..3945bdb Binary files /dev/null and b/input_example/hairs/00170_hair.png differ diff --git a/input_example/hairs/00171_hair.png b/input_example/hairs/00171_hair.png new file mode 100644 index 0000000..e0a09df Binary files /dev/null and b/input_example/hairs/00171_hair.png differ diff --git a/input_example/hairs/00172_hair.png b/input_example/hairs/00172_hair.png new file mode 100644 index 0000000..c1b90a4 Binary files /dev/null and b/input_example/hairs/00172_hair.png differ diff --git a/input_example/hairs/00174_hair.png 
b/input_example/hairs/00174_hair.png new file mode 100644 index 0000000..37fea23 Binary files /dev/null and b/input_example/hairs/00174_hair.png differ diff --git a/input_example/hairs/00175_hair.png b/input_example/hairs/00175_hair.png new file mode 100644 index 0000000..cc5ae5e Binary files /dev/null and b/input_example/hairs/00175_hair.png differ diff --git a/input_example/hairs/00176_hair.png b/input_example/hairs/00176_hair.png new file mode 100644 index 0000000..9001fdb Binary files /dev/null and b/input_example/hairs/00176_hair.png differ diff --git a/input_example/hairs/00177_hair.png b/input_example/hairs/00177_hair.png new file mode 100644 index 0000000..985229e Binary files /dev/null and b/input_example/hairs/00177_hair.png differ diff --git a/input_example/hairs/00178_hair.png b/input_example/hairs/00178_hair.png new file mode 100644 index 0000000..df57a71 Binary files /dev/null and b/input_example/hairs/00178_hair.png differ diff --git a/input_example/hairs/00180_hair.png b/input_example/hairs/00180_hair.png new file mode 100644 index 0000000..bb938eb Binary files /dev/null and b/input_example/hairs/00180_hair.png differ diff --git a/input_example/hairs/00181_hair.png b/input_example/hairs/00181_hair.png new file mode 100644 index 0000000..2fc111e Binary files /dev/null and b/input_example/hairs/00181_hair.png differ diff --git a/input_example/hairs/00182_hair.png b/input_example/hairs/00182_hair.png new file mode 100644 index 0000000..b855165 Binary files /dev/null and b/input_example/hairs/00182_hair.png differ diff --git a/input_example/hairs/00183_hair.png b/input_example/hairs/00183_hair.png new file mode 100644 index 0000000..95c64ed Binary files /dev/null and b/input_example/hairs/00183_hair.png differ diff --git a/input_example/hairs/00184_hair.png b/input_example/hairs/00184_hair.png new file mode 100644 index 0000000..60631e1 Binary files /dev/null and b/input_example/hairs/00184_hair.png differ diff --git a/input_example/hairs/00185_hair.png b/input_example/hairs/00185_hair.png new file mode 100644 index 0000000..7d71d6f Binary files /dev/null and b/input_example/hairs/00185_hair.png differ diff --git a/input_example/hairs/00186_hair.png b/input_example/hairs/00186_hair.png new file mode 100644 index 0000000..6ab0270 Binary files /dev/null and b/input_example/hairs/00186_hair.png differ diff --git a/input_example/hairs/00187_hair.png b/input_example/hairs/00187_hair.png new file mode 100644 index 0000000..0dee446 Binary files /dev/null and b/input_example/hairs/00187_hair.png differ diff --git a/input_example/hairs/00188_hair.png b/input_example/hairs/00188_hair.png new file mode 100644 index 0000000..1d50f2e Binary files /dev/null and b/input_example/hairs/00188_hair.png differ diff --git a/input_example/hairs/00189_hair.png b/input_example/hairs/00189_hair.png new file mode 100644 index 0000000..09bb288 Binary files /dev/null and b/input_example/hairs/00189_hair.png differ diff --git a/input_example/hairs/00191_hair.png b/input_example/hairs/00191_hair.png new file mode 100644 index 0000000..62531fe Binary files /dev/null and b/input_example/hairs/00191_hair.png differ diff --git a/input_example/hairs/00192_hair.png b/input_example/hairs/00192_hair.png new file mode 100644 index 0000000..bfea5dd Binary files /dev/null and b/input_example/hairs/00192_hair.png differ diff --git a/input_example/hairs/00193_hair.png b/input_example/hairs/00193_hair.png new file mode 100644 index 0000000..ab85c5b Binary files /dev/null and b/input_example/hairs/00193_hair.png differ 
diff --git a/input_example/hairs/00194_hair.png b/input_example/hairs/00194_hair.png new file mode 100644 index 0000000..43bfe57 Binary files /dev/null and b/input_example/hairs/00194_hair.png differ diff --git a/input_example/hairs/00195_hair.png b/input_example/hairs/00195_hair.png new file mode 100644 index 0000000..8dc9ea9 Binary files /dev/null and b/input_example/hairs/00195_hair.png differ diff --git a/input_example/hairs/00196_hair.png b/input_example/hairs/00196_hair.png new file mode 100644 index 0000000..6967855 Binary files /dev/null and b/input_example/hairs/00196_hair.png differ diff --git a/input_example/hairs/00197_hair.png b/input_example/hairs/00197_hair.png new file mode 100644 index 0000000..130b130 Binary files /dev/null and b/input_example/hairs/00197_hair.png differ diff --git a/input_example/skins/00000_skin.png b/input_example/skins/00000_skin.png new file mode 100644 index 0000000..98a457b Binary files /dev/null and b/input_example/skins/00000_skin.png differ diff --git a/input_example/skins/00001_skin.png b/input_example/skins/00001_skin.png new file mode 100644 index 0000000..78b7e89 Binary files /dev/null and b/input_example/skins/00001_skin.png differ diff --git a/input_example/skins/00002_skin.png b/input_example/skins/00002_skin.png new file mode 100644 index 0000000..f514ee6 Binary files /dev/null and b/input_example/skins/00002_skin.png differ diff --git a/input_example/skins/00003_skin.png b/input_example/skins/00003_skin.png new file mode 100644 index 0000000..ffba0a3 Binary files /dev/null and b/input_example/skins/00003_skin.png differ diff --git a/input_example/skins/00004_skin.png b/input_example/skins/00004_skin.png new file mode 100644 index 0000000..4e41b01 Binary files /dev/null and b/input_example/skins/00004_skin.png differ diff --git a/input_example/skins/00005_skin.png b/input_example/skins/00005_skin.png new file mode 100644 index 0000000..cc58e2c Binary files /dev/null and b/input_example/skins/00005_skin.png differ diff --git a/input_example/skins/00006_skin.png b/input_example/skins/00006_skin.png new file mode 100644 index 0000000..35621a5 Binary files /dev/null and b/input_example/skins/00006_skin.png differ diff --git a/input_example/skins/00007_skin.png b/input_example/skins/00007_skin.png new file mode 100644 index 0000000..4b76b1a Binary files /dev/null and b/input_example/skins/00007_skin.png differ diff --git a/input_example/skins/00008_skin.png b/input_example/skins/00008_skin.png new file mode 100644 index 0000000..34c9ef1 Binary files /dev/null and b/input_example/skins/00008_skin.png differ diff --git a/input_example/skins/00009_skin.png b/input_example/skins/00009_skin.png new file mode 100644 index 0000000..b6a4e15 Binary files /dev/null and b/input_example/skins/00009_skin.png differ diff --git a/input_example/skins/00010_skin.png b/input_example/skins/00010_skin.png new file mode 100644 index 0000000..e9e09ab Binary files /dev/null and b/input_example/skins/00010_skin.png differ diff --git a/input_example/skins/00011_skin.png b/input_example/skins/00011_skin.png new file mode 100644 index 0000000..eb38d26 Binary files /dev/null and b/input_example/skins/00011_skin.png differ diff --git a/input_example/skins/00012_skin.png b/input_example/skins/00012_skin.png new file mode 100644 index 0000000..8e7510f Binary files /dev/null and b/input_example/skins/00012_skin.png differ diff --git a/input_example/skins/00013_skin.png b/input_example/skins/00013_skin.png new file mode 100644 index 0000000..b3e75d4 Binary files /dev/null 
and b/input_example/skins/00013_skin.png differ diff --git a/input_example/skins/00014_skin.png b/input_example/skins/00014_skin.png new file mode 100644 index 0000000..6aaf34f Binary files /dev/null and b/input_example/skins/00014_skin.png differ diff --git a/input_example/skins/00015_skin.png b/input_example/skins/00015_skin.png new file mode 100644 index 0000000..3d8f09e Binary files /dev/null and b/input_example/skins/00015_skin.png differ diff --git a/input_example/skins/00016_skin.png b/input_example/skins/00016_skin.png new file mode 100644 index 0000000..b280deb Binary files /dev/null and b/input_example/skins/00016_skin.png differ diff --git a/input_example/skins/00017_skin.png b/input_example/skins/00017_skin.png new file mode 100644 index 0000000..69a87d8 Binary files /dev/null and b/input_example/skins/00017_skin.png differ diff --git a/input_example/skins/00018_skin.png b/input_example/skins/00018_skin.png new file mode 100644 index 0000000..e8d8100 Binary files /dev/null and b/input_example/skins/00018_skin.png differ diff --git a/input_example/skins/00019_skin.png b/input_example/skins/00019_skin.png new file mode 100644 index 0000000..935088b Binary files /dev/null and b/input_example/skins/00019_skin.png differ diff --git a/input_example/skins/00020_skin.png b/input_example/skins/00020_skin.png new file mode 100644 index 0000000..39513ef Binary files /dev/null and b/input_example/skins/00020_skin.png differ diff --git a/input_example/skins/00021_skin.png b/input_example/skins/00021_skin.png new file mode 100644 index 0000000..b03b417 Binary files /dev/null and b/input_example/skins/00021_skin.png differ diff --git a/input_example/skins/00022_skin.png b/input_example/skins/00022_skin.png new file mode 100644 index 0000000..ca7e8fd Binary files /dev/null and b/input_example/skins/00022_skin.png differ diff --git a/input_example/skins/00023_skin.png b/input_example/skins/00023_skin.png new file mode 100644 index 0000000..664056b Binary files /dev/null and b/input_example/skins/00023_skin.png differ diff --git a/input_example/skins/00024_skin.png b/input_example/skins/00024_skin.png new file mode 100644 index 0000000..5f4369a Binary files /dev/null and b/input_example/skins/00024_skin.png differ diff --git a/input_example/skins/00025_skin.png b/input_example/skins/00025_skin.png new file mode 100644 index 0000000..ee9f5c8 Binary files /dev/null and b/input_example/skins/00025_skin.png differ diff --git a/input_example/skins/00026_skin.png b/input_example/skins/00026_skin.png new file mode 100644 index 0000000..bd7bbcf Binary files /dev/null and b/input_example/skins/00026_skin.png differ diff --git a/input_example/skins/00027_skin.png b/input_example/skins/00027_skin.png new file mode 100644 index 0000000..5f44bae Binary files /dev/null and b/input_example/skins/00027_skin.png differ diff --git a/input_example/skins/00028_skin.png b/input_example/skins/00028_skin.png new file mode 100644 index 0000000..f3c665b Binary files /dev/null and b/input_example/skins/00028_skin.png differ diff --git a/input_example/skins/00029_skin.png b/input_example/skins/00029_skin.png new file mode 100644 index 0000000..d58bbdf Binary files /dev/null and b/input_example/skins/00029_skin.png differ diff --git a/input_example/skins/00030_skin.png b/input_example/skins/00030_skin.png new file mode 100644 index 0000000..6f0489b Binary files /dev/null and b/input_example/skins/00030_skin.png differ diff --git a/input_example/skins/00031_skin.png b/input_example/skins/00031_skin.png new file mode 100644 
index 0000000..9deaff7 Binary files /dev/null and b/input_example/skins/00031_skin.png differ [binary new-file entries for input_example/skins/00032_skin.png through input_example/skins/00154_skin.png, each of the same form] diff --git a/input_example/skins/00155_skin.png b/input_example/skins/00155_skin.png new file mode 100644 index 0000000..b30355d Binary files /dev/null
and b/input_example/skins/00155_skin.png differ diff --git a/input_example/skins/00156_skin.png b/input_example/skins/00156_skin.png new file mode 100644 index 0000000..59050a2 Binary files /dev/null and b/input_example/skins/00156_skin.png differ diff --git a/input_example/skins/00157_skin.png b/input_example/skins/00157_skin.png new file mode 100644 index 0000000..9bff6e8 Binary files /dev/null and b/input_example/skins/00157_skin.png differ diff --git a/input_example/skins/00158_skin.png b/input_example/skins/00158_skin.png new file mode 100644 index 0000000..9ab669c Binary files /dev/null and b/input_example/skins/00158_skin.png differ diff --git a/input_example/skins/00159_skin.png b/input_example/skins/00159_skin.png new file mode 100644 index 0000000..7844e5d Binary files /dev/null and b/input_example/skins/00159_skin.png differ diff --git a/input_example/skins/00160_skin.png b/input_example/skins/00160_skin.png new file mode 100644 index 0000000..55ae045 Binary files /dev/null and b/input_example/skins/00160_skin.png differ diff --git a/input_example/skins/00161_skin.png b/input_example/skins/00161_skin.png new file mode 100644 index 0000000..eb6d42b Binary files /dev/null and b/input_example/skins/00161_skin.png differ diff --git a/input_example/skins/00162_skin.png b/input_example/skins/00162_skin.png new file mode 100644 index 0000000..6e2b649 Binary files /dev/null and b/input_example/skins/00162_skin.png differ diff --git a/input_example/skins/00163_skin.png b/input_example/skins/00163_skin.png new file mode 100644 index 0000000..17e9e9e Binary files /dev/null and b/input_example/skins/00163_skin.png differ diff --git a/input_example/skins/00164_skin.png b/input_example/skins/00164_skin.png new file mode 100644 index 0000000..125ecf5 Binary files /dev/null and b/input_example/skins/00164_skin.png differ diff --git a/input_example/skins/00165_skin.png b/input_example/skins/00165_skin.png new file mode 100644 index 0000000..2ac10ba Binary files /dev/null and b/input_example/skins/00165_skin.png differ diff --git a/input_example/skins/00166_skin.png b/input_example/skins/00166_skin.png new file mode 100644 index 0000000..dc2cd76 Binary files /dev/null and b/input_example/skins/00166_skin.png differ diff --git a/input_example/skins/00167_skin.png b/input_example/skins/00167_skin.png new file mode 100644 index 0000000..025fd7d Binary files /dev/null and b/input_example/skins/00167_skin.png differ diff --git a/input_example/skins/00388_skin.png b/input_example/skins/00388_skin.png new file mode 100644 index 0000000..bb39fde Binary files /dev/null and b/input_example/skins/00388_skin.png differ diff --git a/input_example/skins/00449_skin.png b/input_example/skins/00449_skin.png new file mode 100644 index 0000000..542161c Binary files /dev/null and b/input_example/skins/00449_skin.png differ diff --git a/input_example/skins/00451_skin.png b/input_example/skins/00451_skin.png new file mode 100644 index 0000000..d50a244 Binary files /dev/null and b/input_example/skins/00451_skin.png differ diff --git a/losses/Full_loss.py b/losses/Full_loss.py new file mode 100644 index 0000000..bfaa3b6 --- /dev/null +++ b/losses/Full_loss.py @@ -0,0 +1,32 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import lpips +import numpy as np +import torch +import torchvision.transforms.functional as F1 +import yaml + +from networks.face_parsing_losses.parse_losses import Total_Faceparseloss +from networks.FARL_losses.farl_losses 
import Total_farlloss +from networks.vgg_face.perceptual import Total_VGGloss as VGGfaceNetwork + + +class Full_loss(torch.nn.Module): + def __init__(self, args): + super(Full_loss, self).__init__() + + dtype = torch.cuda.FloatTensor + self.args = args + self.vggface_loss = VGGfaceNetwork(args["networks"]["VGGface"]) + self.parsingloss = Total_Faceparseloss(args["networks"]["Semantics"]) + self.farl_loss = Total_farlloss(args["networks"]["FARL"]) + + def forward(self, pred_img, identity_img, t, diffusion=None): + loss = 0 + loss = loss + self.vggface_loss(pred_img, identity_img, t) + loss = loss + self.farl_loss(pred_img, t, identity_img, diffusion) + loss = loss + self.parsingloss(pred_img, identity_img, t) + + return loss diff --git a/losses/ssim.py b/losses/ssim.py new file mode 100644 index 0000000..c4927aa --- /dev/null +++ b/losses/ssim.py @@ -0,0 +1,87 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2023 Po-Hsun-Su + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code taken from https://github.com/Po-Hsun-Su/pytorch-ssim/tree/master -- MIT License + + +from math import exp + +import numpy as np +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return 
_ssim(img1, img2, window, window_size, channel, size_average) diff --git a/networks/FARL_losses/farl_losses.py b/networks/FARL_losses/farl_losses.py new file mode 100644 index 0000000..a381584 --- /dev/null +++ b/networks/FARL_losses/farl_losses.py @@ -0,0 +1,90 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# SPDX-License-Identifier: AGPL-3.0-or-later + +import clip +import torch +import torch.nn as nn +import torchvision.transforms.functional as F1 +from PIL import Image +from torch.nn import functional as F +from torchvision import transforms + + +def d_clip_loss(x, y, use_cosine=True): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + + if use_cosine: + distance = 1 - (x @ y.t()).squeeze() + else: + distance = (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + return distance + + +def find_cossim(x, y): + sim_fun = torch.nn.CosineSimilarity(dim=1, eps=1e-08) + return sim_fun(x, y) + + +class Total_farlloss(nn.Module): + def __init__(self, args): + super(Total_farlloss, self).__init__() + self.args = args + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = clip.load("ViT-B/16", device=device, jit=False)[0].eval().requires_grad_(False) + self.model = self.model.to(device) + farl_state = torch.load( + self.args["checkpoint"] + ) # you can download from https://github.com/FacePerceiver/FaRL#pre-trained-backbones + self.model.load_state_dict(farl_state["state_dict"], strict=False) + self.clip_normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] + ) + self.clip_size = self.model.visual.input_resolution + + self.model.eval() + + def forward_image(self, x_in): + x_in = x_in.add(1).div(2) + + clip_loss = torch.tensor(0) + clip_in = self.clip_normalize(x_in) + clip_in = F1.resize(clip_in, [self.clip_size, self.clip_size]) + image_embeds = self.model.encode_image(clip_in).float() + return image_embeds + + def forward(self, pred_img, t, gt=None, diffusion=None): + loss = 0 + b = pred_img.shape[0] + clip_loss = 0 + id_loss = 0 + if self.args["farlclip"]["use"]: + if t <= self.args["farlclip"]["max_t"] and t >= self.args["farlclip"]["min_t"]: + pred_image_embed = self.forward_image(pred_img) + text_embed = self.model.encode_text(clip.tokenize(self.args["farlclip"]["prompt"]).to("cuda:0")).float() + clip_loss_full = d_clip_loss(pred_image_embed, text_embed) + clip_loss = clip_loss + clip_loss_full.mean() + loss = loss + clip_loss * self.args["farlclip"]["lambda"] + if self.args["farledit"]["use"]: + if t <= self.args["farledit"]["max_t"] and t >= self.args["farledit"]["min_t"]: + noise = torch.randn_like(gt) + gt_noisy = gt + pred_image_embed = self.forward_image(pred_img) + gt_image_embed = self.forward_image(gt_noisy) + + text_embed = self.model.encode_text(clip.tokenize(self.args["farledit"]["prompt"]).to("cuda:0")).float() + clip_loss_full = d_clip_loss(pred_image_embed, text_embed) + clip_loss = clip_loss + clip_loss_full.mean() + loss = loss + clip_loss * self.args["farledit"]["lambda"] + if self.args["farlidentity"]["use"]: + + if t <= self.args["farlidentity"]["max_t"] and t >= self.args["farlidentity"]["min_t"]: + pred_image_embed = self.forward_image(pred_img) + gt_image_embed = self.forward_image(gt) + id_loss_full = d_clip_loss(pred_image_embed, gt_image_embed) + + id_loss = id_loss + id_loss_full.mean() + + loss = loss + self.args["farlidentity"]["lambda"] * id_loss + return loss diff --git a/networks/FARL_losses/fp_farl_loss.py 
b/networks/FARL_losses/fp_farl_loss.py new file mode 100644 index 0000000..8df2b2a --- /dev/null +++ b/networks/FARL_losses/fp_farl_loss.py @@ -0,0 +1,220 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# SPDX-License-Identifier: AGPL-3.0-or-later +import sys + +import torch + +sys.path.append("..") +device = "cuda" +import os + +import cv2 +import facer +import numpy as np +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _WeightedLoss +from torchvision.transforms import functional as TF + +label_map = { + "background": 0, + "skin": 1, + "left_eyebrow": 2, + "right_eyebrow": 3, + "left_eye": 4, + "right_eye": 5, + "nose": 6, + "upper_lip": 7, + "inner_mouth": 8, + "lower_lip": 9, + "hair": 10, +} + + +class LogNLLLoss(_WeightedLoss): + __constants__ = ["weight", "reduction", "ignore_index"] + + def __init__(self, weight=None, size_average=None, reduce=None, reduction=None, ignore_index=-100): + super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, y_input, y_target): + return cross_entropy(y_input, y_target, weight=self.weight, ignore_index=self.ignore_index) + + +class parsefacesegment_faces(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss", label_idx="hair", save=False): + super(parsefacesegment_faces, self).__init__() + + self.label_idx = label_map[label_idx] + self.save = False + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + self.face_detector.eval() + self.face_parser.eval() + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save_fold="./parsed/", thres=0.9): + idx = self.label_idx + pred_img = (pred_img + 1) / 2 + pred_img = pred_img + pred_img = pred_img * 255.0 + gt = (gt + 1) / 2 + gt = gt.clip(0, 1) + gt = gt * 255.0 + gt_clone = torch.clone(gt) + gtfaces1 = self.face_detector(gt_clone) + + with torch.inference_mode(): + gtfaces = self.face_parser(gt, gtfaces1) + + gtseg_logits = gtfaces["seg"]["logits"] + gtout = gtseg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + faces = self.face_parser(pred_img, gtfaces1) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + a, b, c, d = out.shape + if self.save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + + parsing = out[:, :, :, :] + gtparsing = gtout[:, :, :, :] + loss = self.loss(gtparsing, parsing) + return loss, parsing, gtparsing + + +class parsefaceloss_faces(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss", label_idx="hair", save=True): + super(parsefaceloss_faces, self).__init__() + + self.label_idx = label_map[label_idx] + self.save = save + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + self.face_detector.eval() + self.face_parser.eval() + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = 
nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save_fold="./parsed/", thres=0.9): + idx = self.label_idx + pred_img = (pred_img + 1) / 2 + pred_img = pred_img + pred_img = pred_img * 255.0 + gt = (gt + 1) / 2 + gt = gt.clip(0, 1) + gt = gt * 255.0 + gt_clone = torch.clone(gt) + gtfaces1 = self.face_detector(gt_clone) + + with torch.inference_mode(): + gtfaces = self.face_parser(gt, gtfaces1) + + gtseg_logits = gtfaces["seg"]["logits"] + gtout = gtseg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + faces = self.face_parser(pred_img, gtfaces1) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + a, b, c, d = out.shape + if self.save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + + gt_mask = gtout[0, idx].detach().cpu().numpy() + gt_mask = np.uint8(gt_mask.clip(0, 1) * 255.0) + cv2.imwrite(fold + "gt_mask.png", gt_mask) + gt_img = gt[0].permute(1, 2, 0).detach().cpu().numpy() # .clip(0,255.0) + cv2.imwrite(fold + "gt_img.png", np.uint8(gt_img[:, :, ::-1])) + gtparsing = gtout[:, idx, :, :] + gtparsing = gtparsing.unsqueeze(1) + masked_gt = gt * gtparsing + masked_pred = pred_img * gtparsing + gt_img = masked_gt[0].permute(1, 2, 0).detach().cpu().numpy().clip(0, 255.0) + cv2.imwrite(fold + "gt_masked.png", gt_img[:, :, ::-1]) + parsing = out[:, idx, :, :] + gtparsing = gtout[:, idx, :, :] + loss = self.loss(gtparsing, parsing) + parsing = parsing.unsqueeze(1) + gtparsing = gtparsing.unsqueeze(1) + return loss, parsing, gtparsing + + +class parsefaceloss(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss"): + super(parsefaceloss, self).__init__() + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save=True, save_fold="./parsed/", thres=0.9): + pred_img = (pred_img + 1) / 2 + pred_img = pred_img # .clip(0,1) + img = pred_img * 255.0 + with torch.inference_mode(): + faces = self.face_detector(img) + + with torch.inference_mode(): + faces = self.face_parser(img, faces) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + print(out.shape) + a, b, c, d = out.shape + if save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + gt_mask = gt[0, 0].detach().cpu().numpy() + gt_mask = np.uint8(gt_mask.clip(0, 1) * 255.0) + cv2.imwrite(fold + "gt_mask.png", gt_mask) + + parsing = out[:, 10, :, :].view(a, 1, c, d) + gt = gt.repeat(a, 1, 1, 1) + loss = self.loss(parsing, gt) + return loss + + +if __name__ == "__main__": + + init_image_pil_transfer = Image.open("./18.jpg").convert("RGB") + init_image_pil_transfer = init_image_pil_transfer.resize((256, 256), Image.BICUBIC) # type: ignore + init_image_transfer = TF.to_tensor(init_image_pil_transfer).cuda().unsqueeze(0).mul(2).sub(1) + faceparser = parsefaceloss_faces() + faceparser(init_image_transfer, init_image_transfer) diff --git 
a/networks/face_parsing_losses/parse_losses.py b/networks/face_parsing_losses/parse_losses.py new file mode 100644 index 0000000..3abffd1 --- /dev/null +++ b/networks/face_parsing_losses/parse_losses.py @@ -0,0 +1,107 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# SPDX-License-Identifier: AGPL-3.0-or-later + +import sys + +import torch + +sys.path.append("..") +device = "cuda" +import os + +import cv2 +import facer +import numpy as np +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _WeightedLoss +from torchvision.transforms import functional as TF + + +class LogNLLLoss(_WeightedLoss): + __constants__ = ["weight", "reduction", "ignore_index"] + + def __init__(self, weight=None, size_average=None, reduce=None, reduction=None, ignore_index=-100): + super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, y_input, y_target): + # y_input = torch.log(y_input + EPSILON) + return cross_entropy(y_input, y_target, weight=self.weight, ignore_index=self.ignore_index) + + +class Total_Faceparseloss(nn.Module): + def __init__(self, args): + super(Total_Faceparseloss, self).__init__() + device = "cuda" + dtype = torch.cuda.FloatTensor + self.args = args + self.device = device + self.save = False + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + # self.face_detector.eval() + # self.face_parser.eval() + + parse_loss_criterion = args["criterion"] + + if parse_loss_criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + + elif parse_loss_criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def detect_faces(self, det_input): + det_input = (det_input + 1) / 2 + det_input = det_input.clip(0, 1) + det_input = det_input * 255.0 + det_clone = torch.clone(det_input) + det_faces = self.face_detector(det_clone) + return det_faces + + def parse_faces(self, det_input, det_faces, setgt=False): + det_input = (det_input + 1) / 2 + det_input = det_input.clip(0, 1) + det_input = det_input * 255.0 + if setgt: + with torch.inference_mode(): + parsefaces = self.face_parser(det_input, det_faces) + else: + parsefaces = self.face_parser(det_input, det_faces) + + parse_logits = parsefaces["seg"]["logits"] + parsed_outputs = parse_logits.softmax(dim=1) # nfaces x nclasses x h x w + + return parsed_outputs + + def forward(self, pred_img, gt_segment, t): + gt_segment = torch.clone(gt_segment).detach() + a, b, c, d = pred_img.shape + + loss = 0 + + if self.args["face_segment_parse"]["use"]: + det_faces = self.detect_faces(gt_segment) + if t[0] <= self.args["face_segment_parse"]["max_t"] and t[0] >= self.args["face_segment_parse"]["min_t"]: + inp_segment = pred_img + gt_segment = torch.clone(gt_segment) + parsed_input = self.parse_faces(inp_segment, det_faces) + parsed_gt = self.parse_faces(gt_segment, det_faces, setgt=True) + req_input = parsed_input + req_gt = parsed_gt + loss_entropy = self.args["face_segment_parse"]["lambda"] * self.loss(req_input, req_gt) + loss = loss + loss_entropy + + return loss + + +if __name__ == "__main__": + + init_image_pil_transfer = Image.open("./18.jpg").convert("RGB") + init_image_pil_transfer = init_image_pil_transfer.resize((256, 256), Image.BICUBIC) # type: ignore + init_image_transfer = 
TF.to_tensor(init_image_pil_transfer).cuda().unsqueeze(0).mul(2).sub(1) + faceparser = parsefaceloss_faces() + # img[img>0]=1 + faceparser(init_image_transfer, init_image_transfer) diff --git a/networks/face_parsing_losses/parsing_losses.py b/networks/face_parsing_losses/parsing_losses.py new file mode 100644 index 0000000..4a81cc3 --- /dev/null +++ b/networks/face_parsing_losses/parsing_losses.py @@ -0,0 +1,223 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# SPDX-License-Identifier: AGPL-3.0-or-later + +import sys + +import torch + +sys.path.append("..") +device = "cuda" +import os + +import cv2 +import facer +import numpy as np +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _WeightedLoss +from torchvision.transforms import functional as TF + +label_map = { + "background": 0, + "skin": 1, + "left_eyebrow": 2, + "right_eyebrow": 3, + "left_eye": 4, + "right_eye": 5, + "nose": 6, + "upper_lip": 7, + "inner_mouth": 8, + "lower_lip": 9, + "hair": 10, +} + + +class LogNLLLoss(_WeightedLoss): + __constants__ = ["weight", "reduction", "ignore_index"] + + def __init__(self, weight=None, size_average=None, reduce=None, reduction=None, ignore_index=-100): + super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, y_input, y_target): + # y_input = torch.log(y_input + EPSILON) + return cross_entropy(y_input, y_target, weight=self.weight, ignore_index=self.ignore_index) + + +class parsefacesegment_faces(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss", label_idx="hair", save=False): + super(parsefacesegment_faces, self).__init__() + + self.label_idx = label_map[label_idx] + self.save = False + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + self.face_detector.eval() + self.face_parser.eval() + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save_fold="./parsed/", thres=0.9): + idx = self.label_idx + pred_img = (pred_img + 1) / 2 + pred_img = pred_img + pred_img = pred_img * 255.0 + gt = (gt + 1) / 2 + gt = gt.clip(0, 1) + gt = gt * 255.0 + gt_clone = torch.clone(gt) + gtfaces1 = self.face_detector(gt_clone) + + with torch.inference_mode(): + gtfaces = self.face_parser(gt, gtfaces1) + + gtseg_logits = gtfaces["seg"]["logits"] + gtout = gtseg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + faces = self.face_parser(pred_img, gtfaces1) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + a, b, c, d = out.shape + if self.save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + + parsing = out[:, :, :, :] + gtparsing = gtout[:, :, :, :] + loss = self.loss(gtparsing, parsing) + return loss, parsing, gtparsing + + +class parsefaceloss_faces(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss", label_idx="hair", save=True): + super(parsefaceloss_faces, self).__init__() + + self.label_idx = label_map[label_idx] + self.save = save + self.face_detector = 
facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + self.face_detector.eval() + self.face_parser.eval() + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save_fold="./parsed/", thres=0.9): + idx = self.label_idx + pred_img = (pred_img + 1) / 2 + pred_img = pred_img + pred_img = pred_img * 255.0 + gt = (gt + 1) / 2 + gt = gt.clip(0, 1) + gt = gt * 255.0 + gt_clone = torch.clone(gt) + gtfaces1 = self.face_detector(gt_clone) + + with torch.inference_mode(): + gtfaces = self.face_parser(gt, gtfaces1) + + gtseg_logits = gtfaces["seg"]["logits"] + gtout = gtseg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + faces = self.face_parser(pred_img, gtfaces1) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + + a, b, c, d = out.shape + if self.save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + + gt_mask = gtout[0, idx].detach().cpu().numpy() + gt_mask = np.uint8(gt_mask.clip(0, 1) * 255.0) + cv2.imwrite(fold + "gt_mask.png", gt_mask) + gt_img = gt[0].permute(1, 2, 0).detach().cpu().numpy() + cv2.imwrite(fold + "gt_img.png", np.uint8(gt_img[:, :, ::-1])) + gtparsing = gtout[:, idx, :, :] + gtparsing = gtparsing.unsqueeze(1) + masked_gt = gt * gtparsing + masked_pred = pred_img * gtparsing + gt_img = masked_gt[0].permute(1, 2, 0).detach().cpu().numpy().clip(0, 255.0) + cv2.imwrite(fold + "gt_masked.png", gt_img[:, :, ::-1]) + parsing = out[:, idx, :, :] + gtparsing = gtout[:, idx, :, :] + loss = self.loss(gtparsing, parsing) + # print(loss) + parsing = parsing.unsqueeze(1) + gtparsing = gtparsing.unsqueeze(1) + return loss, parsing, gtparsing + + +class parsefaceloss(torch.nn.Module): + def __init__(self, criterion="nn.BCEWithLogitsLoss"): + super(parsefaceloss, self).__init__() + self.face_detector = facer.face_detector("retinaface/mobilenet", device=device) + self.face_parser = facer.face_parser("farl/lapa/448", device=device) + + if criterion == "LogNLLLoss": + self.loss = LogNLLLoss() + elif criterion == "nn.BCEWithLogitsLoss": + self.loss = nn.BCEWithLogitsLoss() + + def forward(self, pred_img, gt, save=True, save_fold="./parsed/", thres=0.9): + pred_img = (pred_img + 1) / 2 + pred_img = pred_img # .clip(0,1) + img = pred_img * 255.0 + with torch.inference_mode(): + faces = self.face_detector(img) + + with torch.inference_mode(): + faces = self.face_parser(img, faces) + + seg_logits = faces["seg"]["logits"] + out = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w + print(out.shape) + a, b, c, d = out.shape + if save: + for i in range(out.shape[1]): + parsed = out[0, i, :, :] + parsed = parsed.detach().cpu().numpy() + parsed = np.uint8(parsed * 255.0) + fold = save_fold + if os.path.exists(fold) == False: + os.mkdir(fold) + cv2.imwrite(fold + str(i) + ".png", parsed) + gt_mask = gt[0, 0].detach().cpu().numpy() + gt_mask = np.uint8(gt_mask.clip(0, 1) * 255.0) + cv2.imwrite(fold + "gt_mask.png", gt_mask) + + parsing = out[:, 10, :, :].view(a, 1, c, d) + gt = gt.repeat(a, 1, 1, 1) + loss = self.loss(parsing, gt) + return loss + + +if __name__ == "__main__": + + init_image_pil_transfer = Image.open("./18.jpg").convert("RGB") + 
init_image_pil_transfer = init_image_pil_transfer.resize((256, 256), Image.BICUBIC) # type: ignore + init_image_transfer = TF.to_tensor(init_image_pil_transfer).cuda().unsqueeze(0).mul(2).sub(1) + faceparser = parsefaceloss_faces() + faceparser(init_image_transfer, init_image_transfer) diff --git a/networks/vgg_face/perceptual.py b/networks/vgg_face/perceptual.py new file mode 100644 index 0000000..cd1dbd8 --- /dev/null +++ b/networks/vgg_face/perceptual.py @@ -0,0 +1,147 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def gray_resize_for_identity(out, size=128): + # print(out.shape) + out_gray = 0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :] + out_gray = out_gray.unsqueeze(1).repeat(1, 3, 1, 1) + # out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + +class Vgg16(torch.nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X): + X1 = gray_resize_for_identity(X) + h = F.relu(self.conv1_1(X1)) + h = F.relu(self.conv1_2(h)) + relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + relu3_3 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + relu4_3 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + relu5_3 = h + + return [relu1_2, relu2_2, relu3_3] + + +def load_vgg(checkpoint): + vgg = Vgg16() + state_dict_g = torch.load(checkpoint) + new_state_dict_g = {} + new_state_dict_g["conv1_1.weight"] = state_dict_g["0.weight"] + new_state_dict_g["conv1_1.bias"] = state_dict_g["0.bias"] + new_state_dict_g["conv1_2.weight"] = state_dict_g["2.weight"] + new_state_dict_g["conv1_2.bias"] = state_dict_g["2.bias"] + new_state_dict_g["conv2_1.weight"] = state_dict_g["5.weight"] + new_state_dict_g["conv2_1.bias"] = state_dict_g["5.bias"] + new_state_dict_g["conv2_2.weight"] = state_dict_g["7.weight"] + new_state_dict_g["conv2_2.bias"] = state_dict_g["7.bias"] + new_state_dict_g["conv3_1.weight"] = state_dict_g["10.weight"] + new_state_dict_g["conv3_1.bias"] = state_dict_g["10.bias"] + 
new_state_dict_g["conv3_2.weight"] = state_dict_g["12.weight"] + new_state_dict_g["conv3_2.bias"] = state_dict_g["12.bias"] + new_state_dict_g["conv3_3.weight"] = state_dict_g["14.weight"] + new_state_dict_g["conv3_3.bias"] = state_dict_g["14.bias"] + new_state_dict_g["conv4_1.weight"] = state_dict_g["17.weight"] + new_state_dict_g["conv4_1.bias"] = state_dict_g["17.bias"] + new_state_dict_g["conv4_2.weight"] = state_dict_g["19.weight"] + new_state_dict_g["conv4_2.bias"] = state_dict_g["19.bias"] + new_state_dict_g["conv4_3.weight"] = state_dict_g["21.weight"] + new_state_dict_g["conv4_3.bias"] = state_dict_g["21.bias"] + new_state_dict_g["conv5_1.weight"] = state_dict_g["24.weight"] + new_state_dict_g["conv5_1.bias"] = state_dict_g["24.bias"] + new_state_dict_g["conv5_2.weight"] = state_dict_g["26.weight"] + new_state_dict_g["conv5_2.bias"] = state_dict_g["26.bias"] + new_state_dict_g["conv5_3.weight"] = state_dict_g["28.weight"] + new_state_dict_g["conv5_3.bias"] = state_dict_g["28.bias"] + vgg.load_state_dict(new_state_dict_g) + return vgg + + +class Total_VGGloss(nn.Module): + def __init__(self, args): + super(Total_VGGloss, self).__init__() + self.args = args + self.vgg_model = load_vgg(args["checkpoint"]) + self.vgg_model.cuda() + self.vgg_model.eval() + self.args = args + + def forward_network(self, pred_img, gt): + loss = [] + pred_img_features = self.vgg_model(pred_img) + gt_features = self.vgg_model(gt) + for pred_img_feature, gt_feature in zip(pred_img_features, gt_features): + loss.append(F.mse_loss(pred_img_feature, gt_feature)) + + return loss # sum(loss)/len(loss) + + def forward(self, pred_img, gt, t): + loss = 0 + + use_VGG = self.args["multiscale"]["use"] or self.args["singlescale"]["use"] + if self.args["multiscale"]["use"]: + min_range = self.args["multiscale"]["min_t"] + max_range = self.args["multiscale"]["max_t"] + if self.args["singlescale"]["use"]: + min_range = self.args["singlescale"]["min_t"] + max_range = self.args["singlescale"]["max_t"] + + if use_VGG: + if t <= max_range and t >= min_range: + loss_val = self.forward_network(pred_img, gt) + else: + return loss + + if t <= max_range and t >= min_range: + if self.args["multiscale"]["use"]: + loss_multi = sum(loss_val) / len(loss_val) + loss = loss_multi * self.args["multiscale"]["lambda"] + + if self.args["singlescale"]["use"]: + loss = loss + loss_val[-1] * self.args["singlescale"]["lambda"] + + return loss diff --git a/parser.py b/parser.py new file mode 100644 index 0000000..5373994 --- /dev/null +++ b/parser.py @@ -0,0 +1,33 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2022-2023 Omri Avrahami + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/omriav/blended-diffusion -- MIT License + +import argparse + + +def get_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-config", "--config", type=str, help="Config file with generations", default="configs/diffusion_config.yml" + ) + parser.add_argument( + "-img_path", "--img_path", type=str, help="Path of example image", default="./input_example/faces/4.jpg" + ) + parser.add_argument( + "-mask_path", "--mask_path", type=str, help="Path of example mask", default="./input_example/masks/4.png" + ) + parser.add_argument("-data_fold", "--data_fold", type=str, help="Path of data fold", default="./data") + parser.add_argument("-condition", "--condition", type=str, help="Required condition", 
default="grayscale") + parser.add_argument( + "-editing_text", + "--editing_text", + type=str, + help="Required text for editing", + default="A woman with blonde hair", + ) + args = parser.parse_args() + return args diff --git a/steered_diffusion.py b/steered_diffusion.py new file mode 100644 index 0000000..759b990 --- /dev/null +++ b/steered_diffusion.py @@ -0,0 +1,132 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2022-2023 Omri Avrahami + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/omriav/blended-diffusion -- MIT License + +import os +from pathlib import Path + +import lpips +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +import tqdm +import yaml +from numpy import random +from PIL import Image, ImageOps +from torchvision.transforms import functional as TF +from tqdm import tqdm + +from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults +from losses.ssim import SSIM + +torch.autograd.set_detect_anomaly(True) + +from losses.Full_loss import Full_loss + + +class ImageEditor: + def __init__(self, args) -> None: + self.args = args + self.data = args["data"] + self.params = args["params"] + self.network = args["diffusion_network"] + self.checkpoints = args["checkpoints"] + self.Full_loss = Full_loss(args) + count = 0 + out_path = os.path.join(self.params["results_dir"]) + if os.path.exists(out_path) == False: + os.makedirs(out_path) + self.data["output_path"] = out_path + + if self.args["seed"] is not None: + torch.manual_seed(self.args["seed"]) + np.random.seed(self.args["seed"]) + random.seed(self.args["seed"]) + + self.model_config = model_and_diffusion_defaults() + + self.model_config.update(self.network) + gpu_id = self.args["gpu_id"] + self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") + print("Using device:", self.device) + + self.model, self.diffusion = create_model_and_diffusion(**self.model_config) + self.model.load_state_dict(torch.load(self.checkpoints["ffhq"])) + self.model.requires_grad_(False).eval().to(self.device) + for name, param in self.model.named_parameters(): + if "qkv" in name or "norm" in name or "proj" in name: + param.requires_grad_() + if self.model_config["use_fp16"]: + self.model.convert_to_fp16() + + self.image_size = (self.params["image_size"], self.params["image_size"]) + + def edit_image(self): + + if self.params["use_ddim"]: + self.init_ddim = self.diffusion.ddim_reverse_sample_loop( + self.model, + self.init_image_transfer, + )["sample"] + + batch = self.params["batch_size"] + img = self.data["init_image"] + init_image = Image.open(img).convert("RGB") + init_image = init_image.resize(self.image_size, Image.BICUBIC) + if self.params["cond"] == "inpaint": + mask_image = Image.open(self.data["init_mask"]).convert("L") + mask_image = mask_image.resize(self.image_size, Image.BICUBIC) + + init_image = TF.to_tensor(init_image).to(self.device).unsqueeze(0).mul(2).sub(1) + self.init_image = init_image + + if self.params["cond"] == "inpaint": + mask_image = TF.to_tensor(mask_image).to(self.device).unsqueeze(0) + mask_image = mask_image.repeat(1, 3, 1, 1) + else: + mask_image = None + image_name = img.split("/")[-1].strip(".jpg") + model_kwargs = { + "cond": self.params["cond"], + "mask_image": mask_image, + "init_image": self.init_image, + "num_iters": 1, + "factor": self.params["scale_factor"], 
+ } + + shape = ( + batch, + 3, + self.model_config["image_size"], + self.model_config["image_size"], + ) + + model_kwargs["dest_fold"] = os.path.join("./results", model_kwargs["cond"]) + + samples = self.diffusion.conditional_sample_loop_progressive( + model=self.model, + shape=shape, + clip_denoised=False, + model_kwargs=model_kwargs, + noise=None, + cond_fn=self.Full_loss, + progress=True, + ) + + for count, sample in enumerate(samples): + pred_image = sample["sample"] + pred_image = pred_image.add(1).div(2).clamp(0, 1) + degraded_image = sample["degraded"].add(1).div(2).clamp(0, 1) + dest_fold = os.path.join(self.params["results_dir"], model_kwargs["cond"], image_name) + if os.path.exists(dest_fold) == False: + os.makedirs(dest_fold) + for j in range(pred_image.shape[0]): + degraded_pred = torch.cat([degraded_image[j], pred_image[j]], dim=2) + pred_image_pil = TF.to_pil_image(degraded_pred) + pred_path = os.path.join(dest_fold, str(j) + ".jpg") + pred_image_pil.save(pred_path) diff --git a/steered_diffusion_dataset.py b/steered_diffusion_dataset.py new file mode 100644 index 0000000..9ed5871 --- /dev/null +++ b/steered_diffusion_dataset.py @@ -0,0 +1,142 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2022-2023 Omri Avrahami + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/omriav/blended-diffusion -- MIT License + + +import glob +import os +from pathlib import Path + +import clip +import lpips +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +import tqdm +import yaml +from numpy import random +from PIL import Image, ImageOps +from torchvision.transforms import functional as TF +from tqdm import tqdm + +from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults +from losses.ssim import SSIM + +torch.autograd.set_detect_anomaly(True) + +from losses.Full_loss import Full_loss + + +class ImageEditor: + def __init__(self, args) -> None: + self.args = args + self.data = args["data"] + self.params = args["params"] + self.network = args["diffusion_network"] + self.checkpoints = args["checkpoints"] + self.Full_loss = Full_loss(args) + count = 0 + out_path = os.path.join(self.params["results_dir"]) + if os.path.exists(out_path) == False: + os.makedirs(out_path) + self.data["output_path"] = out_path + + if self.args["seed"] is not None: + torch.manual_seed(self.args["seed"]) + np.random.seed(self.args["seed"]) + random.seed(self.args["seed"]) + + self.model_config = model_and_diffusion_defaults() + + self.model_config.update(self.network) + gpu_id = self.args["gpu_id"] + self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") + print("Using device:", self.device) + + self.model, self.diffusion = create_model_and_diffusion(**self.model_config) + self.model.load_state_dict(torch.load(self.checkpoints["ffhq"])) + self.model.requires_grad_(False).eval().to(self.device) + for name, param in self.model.named_parameters(): + if "qkv" in name or "norm" in name or "proj" in name: + param.requires_grad_() + if self.model_config["use_fp16"]: + self.model.convert_to_fp16() + + self.image_size = (self.params["image_size"], self.params["image_size"]) + + def edit_image(self): + + if self.params["use_ddim"]: + self.init_ddim = self.diffusion.ddim_reverse_sample_loop( + self.model, + self.init_image_transfer, + )["sample"] + + batch = 
self.params["batch_size"] + + images = glob.glob(os.path.join(self.data["data_fold"], "images", "*")) + if self.params["cond"] == "inpaint": + masks = glob.glob(os.path.join(self.data["data_fold"], "masks", "*")) + else: + masks = [None] * len(images) + + for img, mask in zip(images, masks): + init_image = Image.open(img).convert("RGB") + init_image = init_image.resize(self.image_size, Image.BICUBIC) + + init_image = TF.to_tensor(init_image).to(self.device).unsqueeze(0).mul(2).sub(1) + self.init_image = init_image + + if self.params["cond"] == "inpaint": + mask_image = Image.open(mask).convert("L") + mask_image = mask_image.resize(self.image_size, Image.BICUBIC) + mask_image = TF.to_tensor(mask_image).to(self.device).unsqueeze(0) + mask_image = mask_image.repeat(1, 3, 1, 1) + else: + mask_image = None + + image_name = img.split("/")[-1].strip(".jpg") + model_kwargs = { + "cond": self.params["cond"], + "mask_image": mask_image, + "init_image": self.init_image, + "num_iters": 1, + "factor": self.params["scale_factor"], + } + + shape = ( + batch, + 3, + self.model_config["image_size"], + self.model_config["image_size"], + ) + + model_kwargs["dest_fold"] = os.path.join("./results", model_kwargs["cond"]) + + samples = self.diffusion.conditional_sample_loop_progressive( + model=self.model, + shape=shape, + clip_denoised=False, + model_kwargs=model_kwargs, + noise=None, + cond_fn=self.Full_loss, + progress=True, + ) + + for count, sample in enumerate(samples): + pred_image = sample["sample"] + pred_image = pred_image.add(1).div(2).clamp(0, 1) + degraded_image = sample["degraded"].add(1).div(2).clamp(0, 1) + dest_fold = os.path.join(self.params["results_dir"], model_kwargs["cond"], image_name) + if os.path.exists(dest_fold) == False: + os.makedirs(dest_fold) + for j in range(pred_image.shape[0]): + degraded_pred = torch.cat([degraded_image[j], pred_image[j]], dim=2) + pred_image_pil = TF.to_pil_image(degraded_pred) + pred_path = os.path.join(dest_fold, str(j) + ".jpg") + pred_image_pil.save(pred_path) diff --git a/steered_generate.py b/steered_generate.py new file mode 100644 index 0000000..3322edf --- /dev/null +++ b/steered_generate.py @@ -0,0 +1,32 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) + +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from parser import get_arguments + +import yaml + +from steered_diffusion import ImageEditor + +if __name__ == "__main__": + args_config = get_arguments() + config = args_config.config + + args_yaml_file = open(config) + args = yaml.load(args_yaml_file, Loader=yaml.FullLoader) + args["data"]["init_image"] = args_config.img_path + args["data"]["init_mask"] = args_config.mask_path + args["data"]["data_fold"] = args_config.data_fold + args["params"]["cond"] = args_config.condition + if args_config.condition == "Semantics": + args["networks"]["Semantics"]["face_segment_parse"]["use"] = True + elif args_config.condition == "Identity": + args["networks"]["FARL"]["farlidentity"]["use"] = True + elif args_config.condition == "editing": + args["networks"]["FARL"]["farledit"]["use"] = True + args["networks"]["VGGface"]["multiscale"]["use"] = True + args["networks"]["FARL"]["farledit"]["prompt"] = args_config.editing_text + + image_editor = ImageEditor(args) + image_editor.edit_image() diff --git a/steered_generate_dataset.py b/steered_generate_dataset.py new file mode 100644 index 0000000..5e4d34d --- /dev/null +++ b/steered_generate_dataset.py @@ -0,0 +1,32 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories 
(MERL) +# +# SPDX-License-Identifier: AGPL-3.0-or-later + + +from parser import get_arguments + +import yaml + +from steered_diffusion_dataset import ImageEditor + +if __name__ == "__main__": + args_config = get_arguments() + config = args_config.config + + args_yaml_file = open(config) + args = yaml.load(args_yaml_file, Loader=yaml.FullLoader) + args["data"]["init_image"] = args_config.img_path + args["data"]["init_mask"] = args_config.mask_path + args["data"]["data_fold"] = args_config.data_fold + args["params"]["cond"] = args_config.condition + if args_config.condition == "Semantics": + args["networks"]["Semantics"]["face_segment_parse"]["use"] = True + elif args_config.condition == "Identity": + args["networks"]["FARL"]["farlidentity"]["use"] = True + elif args_config.condition == "editing": + args["networks"]["FARL"]["farledit"]["use"] = True + args["networks"]["VGGface"]["multiscale"]["use"] = True + args["networks"]["FARL"]["farledit"]["prompt"] = args_config.editing_text + + image_editor = ImageEditor(args) + image_editor.edit_image() diff --git a/utils/download_models.py b/utils/download_models.py new file mode 100644 index 0000000..1882f7f --- /dev/null +++ b/utils/download_models.py @@ -0,0 +1,13 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2022-2023 Nithin Gopalakrishnan Nair + + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: Apache-2.0 + +# Code adapted from https://github.com/Nithin-GK/UniteandConquer/blob/main/download_models.py -- Apache-2.0 license + +import numpy as np +from download_models_func import download_files + +download_files() diff --git a/utils/download_models_func.py b/utils/download_models_func.py new file mode 100644 index 0000000..97f83f1 --- /dev/null +++ b/utils/download_models_func.py @@ -0,0 +1,137 @@ +# Copyright (C) 2023-2024 Mitsubishi Electric Research Laboratories (MERL) +# Copyright (C) 2022-2023 Nithin Gopalakrishnan Nair +# Copyright (C) 2021-2022 OpenAi + + +# SPDX-License-Identifier: AGPL-3.0-or-later +# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: MIT + +# Code adapted from https://github.com/openai/glide-text2im/blob/main/glide_text2im/download.py -- MIT License +# Code adapted from https://github.com/Nithin-GK/UniteandConquer/blob/main/download_models.py -- Apache-2.0 license + +import os +from functools import lru_cache +from typing import Dict, Optional + +import requests +from filelock import FileLock +from tqdm.auto import tqdm + + +@lru_cache() +def default_cache_dir(): + return os.path.join(os.path.abspath(os.getcwd()), "checkpoints") + + +MODEL_PATHS = { + "model_face": "https://www.dropbox.com/scl/fi/jcv8a178943o10ml02f3r/ffhq_10m.pt?rlkey=o3nl8gpbg24l49uk1z3xmdpv4&dl=1", + "arcface": "https://www.dropbox.com/scl/fi/8yf5tw71xbdf6a7nyzg0a/arcface18.pth?rlkey=9qa4e4y1digdvmnt7huzxyjie&dl=1", + "farl_clip": "https://www.dropbox.com/scl/fi/6xwjn5amuu2zyjpbaxz5q/FaRL-Base-Patch16-LAIONFace20M-ep64.pth?rlkey=jszbu9zbmq5euyj97xjdp4bnk&dl=1", + "farl_parse": "https://www.dropbox.com/scl/fi/fa3mmuom0sagg7b6x61gb/face_parse.pth?rlkey=4c45rtoydue5xyb5bkg36iam8&dl=1", + "vggface": "https://www.dropbox.com/scl/fi/se50l3z1iaafccxiksf1r/VGG_FACE.pth?rlkey=8pc9m8na7cxlfv2wdme7fecqn&dl=1", + "imagenet_diffusion": "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt", +} + + +LOCAL_PATHS = { + "model_face": "./checkpoints/ffhq_10m.pt", + "farl_clip": "./checkpoints/FaRL-Base-Patch16-LAIONFace20M-ep64.pth", + 
"farl_parse": "./checkpoints/face_parse.pth", + "arcface": "./checkpoints/arcface18.pth", + "vggface": "./checkpoints/VGG_FACE.pth", + "imagenet_diffusion": "./checkpoints/diffusion256x256.pt", +} + +if os.path.exists("./checkpoints") == False: + os.mkdir("./checkpoints") +# taken from this StackOverflow answer: https://stackoverflow.com/a/39225039 +import requests + + +def download_file_from_google_drive(id, destination): + URL = "https://docs.google.com/uc?export=download" + + session = requests.Session() + + response = session.get(URL, params={"id": id}, stream=True) + token = get_confirm_token(response) + + if token: + params = {"id": id, "confirm": token} + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + return value + + return None + + +def save_response_content(response, destination): + CHUNK_SIZE = 32768 + + with open(destination, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + +def download_file_from_google_drive(id, destination): + URL = "https://docs.google.com/uc?export=download" + + session = requests.Session() + + response = session.get(URL, params={"id": id}, stream=True) + token = get_confirm_token(response) + + if token: + params = {"id": id, "confirm": token} + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) + + +def fetch_file_cached( + url: str, key: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096 +) -> str: + """ + Download the file at the given URL into a local file and return the path. + If cache_dir is specified, it will be used to download the files. + Otherwise, default_cache_dir() is used. + """ + if cache_dir is None: + cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + local_path = LOCAL_PATHS[key] + print(local_path) + if os.path.exists(local_path): + return LOCAL_PATHS[key] + response = requests.get(url, stream=True) + size = int(response.headers.get("content-length", "0")) + with FileLock(local_path + ".lock"): + if progress: + pbar = tqdm(total=size, unit="iB", unit_scale=True) + tmp_path = local_path + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in response.iter_content(chunk_size): + if progress: + pbar.update(len(chunk)) + f.write(chunk) + os.rename(tmp_path, local_path) + if progress: + pbar.close() + return local_path + + +def download_files(): + for _ in MODEL_PATHS: + model = fetch_file_cached(MODEL_PATHS[_], _) + + +if __name__ == "__main__": + download_files() diff --git a/utils/introfig.png b/utils/introfig.png new file mode 100644 index 0000000..42db72f Binary files /dev/null and b/utils/introfig.png differ diff --git a/utils/isclinear.drawio.png b/utils/isclinear.drawio.png new file mode 100644 index 0000000..749e2b3 Binary files /dev/null and b/utils/isclinear.drawio.png differ diff --git a/utils/steeredv3.drawio.png b/utils/steeredv3.drawio.png new file mode 100644 index 0000000..3c57c01 Binary files /dev/null and b/utils/steeredv3.drawio.png differ