From f7d4a90300b53db75a0b3e3b751a96641a970574 Mon Sep 17 00:00:00 2001 From: shio <85730998+dino3616@users.noreply.github.com> Date: Wed, 13 Nov 2024 07:41:30 +0900 Subject: [PATCH] =?UTF-8?q?initial:=20=F0=9F=8E=89=20first=20commit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/devcontainer.json | 20 ++ .dockerignore | 30 ++ .editorconfig | 12 + .env.example | 5 + .github/workflows/app-test.yaml | 19 ++ .gitignore | 33 ++ .python-version | 1 + .vscode/settings.json | 187 +++++++++++ LICENSE | 201 ++++++++++++ README.md | 100 ++++++ docker/Dockerfile.development | 26 ++ docker/docker-compose.development.yaml | 13 + lefthook.yaml | 9 + pyproject.toml | 37 +++ requirements-dev.lock | 264 +++++++++++++++ requirements.lock | 262 +++++++++++++++ ruff.toml | 48 +++ src/__init__.py | 0 src/dataset.py | 27 ++ src/llama/__init__.py | 0 src/llama/build.py | 95 ++++++ src/llama/conf/accelerate/deepspeed.yaml | 13 + src/llama/conf/accelerate/fsdp.yaml | 21 ++ src/llama/conf/deepspeed/zero2.json | 14 + src/llama/conf/variant/base.yaml | 117 +++++++ src/llama/conf/variant/expert.yaml | 117 +++++++ src/llama/conf/variant/merge.yaml | 117 +++++++ src/llama/conf/variant/moe.yaml | 107 ++++++ src/llama/eval.py | 142 ++++++++ src/llama/merge.py | 38 +++ src/llama/processing.py | 397 +++++++++++++++++++++++ src/llama/train.py | 185 +++++++++++ webnavix.code-workspace | 8 + 33 files changed, 2665 insertions(+) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .dockerignore create mode 100644 .editorconfig create mode 100644 .env.example create mode 100644 .github/workflows/app-test.yaml create mode 100644 .gitignore create mode 100644 .python-version create mode 100644 .vscode/settings.json create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docker/Dockerfile.development create mode 100644 docker/docker-compose.development.yaml create mode 100644 lefthook.yaml create mode 100644 pyproject.toml create mode 100644 requirements-dev.lock create mode 100644 requirements.lock create mode 100644 ruff.toml create mode 100644 src/__init__.py create mode 100644 src/dataset.py create mode 100644 src/llama/__init__.py create mode 100644 src/llama/build.py create mode 100644 src/llama/conf/accelerate/deepspeed.yaml create mode 100644 src/llama/conf/accelerate/fsdp.yaml create mode 100644 src/llama/conf/deepspeed/zero2.json create mode 100644 src/llama/conf/variant/base.yaml create mode 100644 src/llama/conf/variant/expert.yaml create mode 100644 src/llama/conf/variant/merge.yaml create mode 100644 src/llama/conf/variant/moe.yaml create mode 100644 src/llama/eval.py create mode 100644 src/llama/merge.py create mode 100644 src/llama/processing.py create mode 100644 src/llama/train.py create mode 100644 webnavix.code-workspace diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..27baf19 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,20 @@ +{ + "name": "webnavix", + "workspaceFolder": "/workspaces/webnavix/", + "dockerComposeFile": ["../docker/docker-compose.development.yaml"], + "service": "app", + "customizations": { + "vscode": { + "extensions": [ + "adam-bender.commit-message-editor", + "charliermarsh.ruff", + "eamodio.gitlens", + "EditorConfig.EditorConfig", + "esbenp.prettier-vscode", + "ms-python.python", + "tamasfe.even-better-toml", + "VisualStudioExptTeam.vscodeintellicode" + ] + } + } +} diff --git a/.dockerignore b/.dockerignore new 
file mode 100644
index 0000000..3f93c3b
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,30 @@
+# cache
+**/__pycache__/
+**/.cache/
+**/.mypy_cache/
+**/.ruff_cache/
+**/*.egg-info/
+
+# dataset
+**/wl_data/
+
+# debug
+**/*log*
+**/wandb/
+
+# deliverable
+**/build/
+**/dist/
+**/out/
+**/results/
+
+# dependency
+**/.venv/
+
+# env file
+**/.env*
+!**/.env.example
+
+# misc
+**/.DS_Store
+**/*.pem
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..4a7ea30
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,12 @@
+root = true
+
+[*]
+indent_style = space
+indent_size = 2
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.md]
+trim_trailing_whitespace = false
diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..cfe222a
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,5 @@
+COMPOSE_PROJECT_NAME="webnavix"
+HF_TOKEN=""
+PROJECT_DIR="/workspaces/webnavix/"
+WANDB_API_KEY=""
+WANDB_PROJECT=""
diff --git a/.github/workflows/app-test.yaml b/.github/workflows/app-test.yaml
new file mode 100644
index 0000000..96dc403
--- /dev/null
+++ b/.github/workflows/app-test.yaml
@@ -0,0 +1,19 @@
+name: app test
+
+on: push
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - name: checkout
+        uses: actions/checkout@v4
+
+      - name: setup rye
+        uses: eifinger/setup-rye@v4
+
+      - name: install dependencies
+        run: rye sync
+
+      - name: check
+        run: rye run check
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6c3c3c9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,33 @@
+# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
+
+# cache
+__pycache__/
+.cache/
+.mypy_cache/
+.ruff_cache/
+*.egg-info/
+
+# dataset
+wl_data
+
+# debug
+*log*
+wandb/
+
+# deliverable
+build/
+dist/
+out/
+results/
+checkpoints
+
+# dependency
+.venv/
+
+# env file
+.env*
+!.env.example
+
+# misc
+.DS_Store
+*.pem
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..e4fba21
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.12
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..f7e8142
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,187 @@
+{
+  "commit-message-editor.tokens": [
+    {
+      "label": "Type",
+      "name": "type",
+      "type": "enum",
+      "description": "Type of changes.",
+      "combobox": true,
+      "options": [
+        {
+          "label": "feat: ✨",
+          "value": "feat: ✨",
+          "description": "Implementation of new features."
+        },
+        {
+          "label": "feat: 🎈",
+          "value": "feat: 🎈",
+          "description": "Repair of existing features."
+        },
+        {
+          "label": "feat: ⚰️",
+          "value": "feat: ⚰️",
+          "description": "Deletion of features."
+        },
+        {
+          "label": "fix: 🐛",
+          "value": "fix: 🐛",
+          "description": "Bug fixes."
+        },
+        {
+          "label": "fix: 🚑️",
+          "value": "fix: 🚑️",
+          "description": "Critical bug fixes or major changes."
+        },
+        {
+          "label": "doc: 📝",
+          "value": "doc: 📝",
+          "description": "Documentation changes."
+        },
+        {
+          "label": "typo: 🖋️",
+          "value": "typo: 🖋️",
+          "description": "Typography changes."
+        },
+        {
+          "label": "style: 💄",
+          "value": "style: 💄",
+          "description": "Style changes."
+        },
+        {
+          "label": "refactor: ♻️",
+          "value": "refactor: ♻️",
+          "description": "Code formatting or refactoring."
+        },
+        {
+          "label": "test: 🧪",
+          "value": "test: 🧪",
+          "description": "Test cases changes."
+        },
+        {
+          "label": "ci: 🦺",
+          "value": "ci: 🦺",
+          "description": "CI changes."
+        },
+        {
+          "label": "build: 📦️",
+          "value": "build: 📦️",
+          "description": "Build system or dependency changes."
+        },
+        {
+          "label": "container: 🐳",
+          "value": "container: 🐳",
+          "description": "The Dockerfile changes."
+        },
+        {
+          "label": "container: 🐙",
+          "value": "container: 🐙",
+          "description": "The docker-compose changes."
+        },
+        {
+          "label": "chore: 🔧",
+          "value": "chore: 🔧",
+          "description": "Configuration changes."
+        },
+        {
+          "label": "chore: 🔨",
+          "value": "chore: 🔨",
+          "description": "Development script changes."
+        },
+        {
+          "label": "chore: 🍱",
+          "value": "chore: 🍱",
+          "description": "Assets changes."
+        },
+        {
+          "label": "revert: ⏪️",
+          "value": "revert: ⏪️",
+          "description": "Reversion of changes."
+        },
+        {
+          "label": "wip: 🚧",
+          "value": "wip: 🚧",
+          "description": "Changes that will be squashed."
+        },
+        {
+          "label": "initial: 🎉",
+          "value": "initial: 🎉",
+          "description": "The first commit."
+        }
+      ]
+    },
+    {
+      "label": "Scope",
+      "name": "scope",
+      "type": "text",
+      "description": "Scope of changes.",
+      "prefix": " (",
+      "suffix": ")"
+    },
+    {
+      "label": "Short Description",
+      "name": "description",
+      "type": "text",
+      "description": "Commit summary.",
+      "prefix": " "
+    },
+    {
+      "label": "Body",
+      "name": "body",
+      "type": "text",
+      "description": "Detailed description of commit.",
+      "maxLines": 10,
+      "multiline": true,
+      "lines": 5
+    },
+    {
+      "label": "Footer",
+      "name": "footer",
+      "description": "Description of disruptive changes or signature.",
+      "type": "text",
+      "multiline": true
+    }
+  ],
+  "commit-message-editor.dynamicTemplate": [
+    "{type}{scope}{description}",
+    "",
+    "{body}",
+    "",
+    "{footer}"
+  ],
+  "commit-message-editor.staticTemplate": [
+    "label: emoji (scope) short-description",
+    "",
+    "body",
+    "",
+    "footer"
+  ],
+  "commit-message-editor.view.defaultView": "form",
+  "editor.defaultFormatter": "esbenp.prettier-vscode",
+  "files.encoding": "utf8",
+  "files.eol": "\n",
+  "python.analysis.exclude": [
+    "**/__pycache__/",
+    "**/.cache/",
+    "**/.mypy_cache/",
+    "**/.ruff_cache/",
+    "**/.venv/",
+    "**/*.egg-info/",
+    "**/build/",
+    "**/checkpoints/",
+    "**/dist/",
+    "**/out/",
+    "**/results/",
+    "**/wandb/",
+    "**/wl_data/",
+    "**/*log/*"
+  ],
+  "python.analysis.typeCheckingMode": "basic",
+  "python.defaultInterpreterPath": "/opt/rye/shims/python",
+  "[python]": {
+    "editor.codeActionsOnSave": {
+      "source.fixAll.ruff": "explicit",
+      "source.organizeImports.ruff": "explicit"
+    },
+    "editor.defaultFormatter": "charliermarsh.ruff"
+  }
+}
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8edc12a
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity.
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 NITIC-NLP-Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3ef83e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,100 @@
+# WebNavix: Continuous Generalist Web Navigation Agent using Domain-wise Mixture-of-Experts
+
+WebNavix is a continuous generalist web navigation agent that merges individually fine-tuned LLMs as domain experts.
+
+## Core Contributors 🛠️
+
+| shio | ituki |
+| :--: | :---: |
+| [shio](https://github.com/dino3616) | [ituki](https://github.com/ituki0426) |
+| `#repository-owner` `#main-author` `#model-composer` | `#co-author` `#model-analyst` |
+
+## Setup with Dev Containers 📦
+
+You can easily launch the development environment of WebNavix with Dev Containers.
+Here is the step-by-step guide.
+
+### Attention
+
+- You need to install [Docker](https://docs.docker.com/get-docker) and [VSCode](https://code.visualstudio.com) beforehand.
+
+### 1. clone git repository
+
+```bash
+git clone "https://github.com/nitic-nlp-team/webnavix" && cd "./webnavix"
+```
+
+### 2. set environment variables
+
+See `.env.example` or contact the [repository owner](https://github.com/dino3616) for more details.
+
+### 3. launch dev containers
+
+Launch containers using the VSCode extension [Dev Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers).
+
+### 4. pin python version
+
+```bash
+rye pin $(cat "./.python-version")
+```
+
+### 5. install dependencies
+
+```bash
+rye sync
+```
+
+### 6. activate virtual environment
+
+```bash
+source "./.venv/bin/activate"
+```
+
+### 7. install FlashAttention-2
+
+```bash
+uv pip install flash-attn --no-build-isolation
+```
+
+## Setup locally 🖥️
+
+If you want to build the environment more quickly without Docker, you can follow these steps to set it up locally.
+
+### Attention
+
+- You need to install [rye](https://rye.astral.sh/guide/installation) beforehand.
+- [Optional] You should install the recommended VSCode extensions that are specified in [`.devcontainer/devcontainer.json`](./.devcontainer/devcontainer.json#L8C7-L17C8) beforehand.
+
+### 1. clone git repository
+
+```bash
+git clone "https://github.com/nitic-nlp-team/webnavix" && cd "./webnavix"
+```
+
+### 2. set environment variables
+
+See `.env.example` or contact the [repository owner](https://github.com/dino3616) for more details.
+
+### 3. pin python version
+
+```bash
+rye pin $(cat "./.python-version")
+```
+
+### 4. install dependencies
+
+```bash
+rye sync
+```
+
+### 5. activate virtual environment
+
+```bash
+source "./.venv/bin/activate"
+```
+
+### 6. install FlashAttention-2
+
+```bash
+uv pip install flash-attn --no-build-isolation
+```
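+
+## Usage 🧭
+
+The pipeline below is a rough sketch of how the scripts under `src/` fit together, inferred from the Hydra configs in `src/llama/conf/variant/`. The `variant=...` override syntax is an assumption: each script is decorated with `config_name="config"`, but the root `conf/config.yaml` is not part of this commit, so adapt the overrides to your actual config layout. The `src.llama` scripts are run with `-m` because they use package-relative imports.
+
+```bash
+# 1. download the WebLINX-full dataset into ./wl_data (retries until complete)
+python "./src/dataset.py"
+
+# 2. build input records for the target split
+python -m src.llama.build variant=base
+
+# 3. fine-tune the shared base model, then each domain expert
+python -m src.llama.train variant=base
+python -m src.llama.train variant=expert
+
+# 4. compose the fine-tuned experts into a Mixture-of-Experts checkpoint with mergoo
+python -m src.llama.merge variant=merge
+
+# 5. train the MoE routers on top of the merged checkpoint, then evaluate
+python -m src.llama.train variant=moe
+python -m src.llama.eval variant=moe
+```
+
+Each file in `src/llama/conf/variant/` pins the base model, checkpoint directory, freeze targets, and merge settings for its stage.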
diff --git a/docker/Dockerfile.development b/docker/Dockerfile.development
new file mode 100644
index 0000000..4c54fe1
--- /dev/null
+++ b/docker/Dockerfile.development
@@ -0,0 +1,26 @@
+FROM python:3.12
+
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
+
+# hadolint ignore=DL3008
+RUN apt-get update \
+    && apt-get --no-install-recommends -y install git gnupg2 ca-certificates curl pipx \
+    && pipx ensurepath \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists
+
+RUN curl -sSf https://rye.astral.sh/get | RYE_INSTALL_OPTION="--yes" bash \
+    && echo "source '$HOME/.rye/env'" >> ~/.bashrc \
+    && /root/.rye/shims/rye config --set-bool behavior.global-python=true \
+    && /root/.rye/shims/rye config --set-bool behavior.use-uv=true
+
+RUN RYE_UV_HOME=$(find "$HOME/.rye/uv" -type d -regex '.*/[0-9]+\.[0-9]+\.[0-9]+$') \
+    && echo "export PATH=\"$RYE_UV_HOME:\$PATH\"" >> ~/.bashrc
+
+WORKDIR /workspaces/webnavix/
+
+COPY ./.python-version ./pyproject.toml ./requirements* ./
+# hadolint ignore=SC1091
+RUN "$HOME/.rye/shims/rye" pin "$(cat ./.python-version)" && "$HOME/.rye/shims/rye" sync && source ./.venv/bin/activate
+
+COPY ./ ./
diff --git a/docker/docker-compose.development.yaml b/docker/docker-compose.development.yaml
new file mode 100644
index 0000000..4d2eb25
--- /dev/null
+++ b/docker/docker-compose.development.yaml
@@ -0,0 +1,13 @@
+services:
+  app:
+    container_name: webnavix
+    build:
+      context: ../
+      dockerfile: ./docker/Dockerfile.development
+    volumes:
+      - type: bind
+        source: ../
+        target: /workspaces/webnavix/
+    environment:
+      PROJECT_DIR: /workspaces/webnavix/
+    tty: true
diff --git a/lefthook.yaml b/lefthook.yaml
new file mode 100644
index 0000000..9de67fa
--- /dev/null
+++ b/lefthook.yaml
@@ -0,0 +1,9 @@
+pre-commit:
+  parallel: true
+  commands:
+    check-py:
+      glob: "*.py"
+      run: ruff check --fix {staged_files}
+    format-py:
+      glob: "*.py"
+      run: ruff format {staged_files}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..956d8a0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,37 @@
+[project]
+name = "webnavix"
+version = "0.1.0"
+description = "Continuous Generalist Web Navigation Agent using domain-wise Mixture-of-Experts that merges individually fine-tuned LLMs as domain experts."
+authors = [
+    { name = "shio", email = "85730998+dino3616@users.noreply.github.com" },
+]
+dependencies = [
+    "accelerate==0.27.2",
+    "bitsandbytes==0.42.0",
+    "datasets==2.19.2",
+    "deepspeed==0.15.1",
+    "huggingface-hub==0.24.6",
+    "hydra-core==1.3.2",
+    "lxml==5.3.0",
+    "mergoo==0.0.10",
+    "omegaconf==2.3.0",
+    "peft==0.12.0",
+    "python-dotenv==1.0.1",
+    "torch==2.4.1",
+    "tqdm==4.66.2",
+    "transformers==4.42.4",
+    "trl==0.10.1",
+    "wandb==0.17.9",
+    "weblinx[eval]==0.3.0",
+]
+readme = "README.md"
+requires-python = "~=3.12"
+
+[tool.rye]
+managed = true
+dev-dependencies = ["lefthook==0.1.2", "ruff==0.6.4"]
+
+[tool.rye.scripts]
+check = { chain = ["lint", "fmt"] }
+lint = "ruff check ./ --diff"
+fmt = "ruff format ./"
diff --git a/requirements-dev.lock b/requirements-dev.lock
new file mode 100644
index 0000000..3b552f2
--- /dev/null
+++ b/requirements-dev.lock
@@ -0,0 +1,264 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+#   pre: false
+#   features: []
+#   all-features: false
+#   with-sources: false
+#   generate-hashes: false
+#   universal: false
+
+-e file:.
+accelerate==0.27.2 + # via mergoo + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +attrs==23.2.0 + # via aiohttp +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via requests + # via sentry-sdk +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via wandb +colorama==0.4.6 + # via sacrebleu +datasets==2.19.2 + # via trl + # via webnavix +deepspeed==0.15.1 + # via webnavix +dill==0.3.7 + # via datasets + # via multiprocess +docker-pycreds==0.4.0 + # via wandb +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2023.10.0 + # via datasets + # via huggingface-hub + # via torch +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via wandb +hjson==3.1.0 + # via deepspeed +huggingface-hub==0.24.6 + # via accelerate + # via datasets + # via mergoo + # via peft + # via tokenizers + # via transformers + # via webnavix +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via requests + # via yarl +jinja2==3.1.4 + # via torch +lefthook==0.1.2 +lxml==5.3.0 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +mdurl==0.1.2 + # via markdown-it-py +mergoo==0.0.10 + # via webnavix +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.15 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via deepspeed +numpy==1.26.4 + # via accelerate + # via datasets + # via deepspeed + # via mergoo + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scipy + # via transformers + # via trl + # via weblinx +omegaconf==2.3.0 + # via hydra-core + # via webnavix +packaging==24.1 + # via accelerate + # via datasets + # via deepspeed + # via huggingface-hub + # via hydra-core + # via peft + # via transformers +pandas==2.2.2 + # via datasets + # via weblinx +peft==0.12.0 + # via mergoo + # via webnavix +platformdirs==4.2.2 + # via wandb +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via mergoo + # via wandb +psutil==6.0.0 + # via accelerate + # via deepspeed + # via peft + # via wandb +py-cpuinfo==9.0.0 + # via deepspeed +pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.9.2 + # via deepspeed +pydantic-core==2.23.4 + # via pydantic +pygments==2.18.0 + # via rich +python-dateutil==2.9.0.post0 + # via pandas +python-dotenv==1.0.1 + # via webnavix +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers + # via wandb +regex==2024.5.15 + # via sacrebleu + # via transformers +requests==2.32.3 + # via datasets + # via fsspec + # via huggingface-hub + # via transformers + # via wandb +rich==13.7.1 + # via tyro +ruff==0.6.4 +sacrebleu==2.4.3 + # via weblinx +safetensors==0.4.3 + # via accelerate + # via mergoo + # via peft + # via transformers +scipy==1.14.0 + # via bitsandbytes +sentencepiece==0.2.0 + # via mergoo +sentry-sdk==2.13.0 + # via wandb +setproctitle==1.3.3 + # via wandb +setuptools==71.1.0 + # via torch + # via wandb +shtab==1.7.1 + # via tyro +six==1.16.0 + # via docker-pycreds + # via python-dateutil +smmap==5.0.1 + # via gitdb +sympy==1.13.1 + # via torch +tabulate==0.9.0 + # via sacrebleu +tokenizers==0.19.1 
+ # via transformers +torch==2.4.1 + # via accelerate + # via deepspeed + # via mergoo + # via peft + # via trl + # via webnavix +tqdm==4.66.2 + # via datasets + # via deepspeed + # via huggingface-hub + # via mergoo + # via peft + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via mergoo + # via peft + # via trl + # via webnavix +trl==0.10.1 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via mergoo + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests + # via sentry-sdk +wandb==0.17.9 + # via webnavix +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..4f7883f --- /dev/null +++ b/requirements.lock @@ -0,0 +1,262 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. +accelerate==0.27.2 + # via mergoo + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +attrs==23.2.0 + # via aiohttp +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via requests + # via sentry-sdk +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via wandb +colorama==0.4.6 + # via sacrebleu +datasets==2.19.2 + # via trl + # via webnavix +deepspeed==0.15.1 + # via webnavix +dill==0.3.7 + # via datasets + # via multiprocess +docker-pycreds==0.4.0 + # via wandb +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2023.10.0 + # via datasets + # via huggingface-hub + # via torch +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via wandb +hjson==3.1.0 + # via deepspeed +huggingface-hub==0.24.6 + # via accelerate + # via datasets + # via mergoo + # via peft + # via tokenizers + # via transformers + # via webnavix +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via requests + # via yarl +jinja2==3.1.4 + # via torch +lxml==5.3.0 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +mdurl==0.1.2 + # via markdown-it-py +mergoo==0.0.10 + # via webnavix +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.15 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via deepspeed +numpy==1.26.4 + # via accelerate + # via datasets + # via deepspeed + # via mergoo + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scipy + # via transformers + # via trl + # via weblinx +omegaconf==2.3.0 + # via hydra-core + # via webnavix +packaging==24.1 + # via accelerate + # via datasets + # via deepspeed + # via huggingface-hub + # via hydra-core + # via peft + # via transformers +pandas==2.2.2 + # via datasets + # via weblinx +peft==0.12.0 + # via mergoo + # via webnavix +platformdirs==4.2.2 + # via wandb +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via mergoo + # via wandb +psutil==6.0.0 + # via accelerate + # via deepspeed + # via peft + # via wandb +py-cpuinfo==9.0.0 + # via deepspeed 
+pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.9.2 + # via deepspeed +pydantic-core==2.23.4 + # via pydantic +pygments==2.18.0 + # via rich +python-dateutil==2.9.0.post0 + # via pandas +python-dotenv==1.0.1 + # via webnavix +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers + # via wandb +regex==2024.5.15 + # via sacrebleu + # via transformers +requests==2.32.3 + # via datasets + # via fsspec + # via huggingface-hub + # via transformers + # via wandb +rich==13.7.1 + # via tyro +sacrebleu==2.4.3 + # via weblinx +safetensors==0.4.3 + # via accelerate + # via mergoo + # via peft + # via transformers +scipy==1.14.0 + # via bitsandbytes +sentencepiece==0.2.0 + # via mergoo +sentry-sdk==2.13.0 + # via wandb +setproctitle==1.3.3 + # via wandb +setuptools==71.1.0 + # via torch + # via wandb +shtab==1.7.1 + # via tyro +six==1.16.0 + # via docker-pycreds + # via python-dateutil +smmap==5.0.1 + # via gitdb +sympy==1.13.1 + # via torch +tabulate==0.9.0 + # via sacrebleu +tokenizers==0.19.1 + # via transformers +torch==2.4.1 + # via accelerate + # via deepspeed + # via mergoo + # via peft + # via trl + # via webnavix +tqdm==4.66.2 + # via datasets + # via deepspeed + # via huggingface-hub + # via mergoo + # via peft + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via mergoo + # via peft + # via trl + # via webnavix +trl==0.10.1 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via mergoo + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests + # via sentry-sdk +wandb==0.17.9 + # via webnavix +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..10c34d9 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,48 @@ +target-version = "py312" +indent-width = 4 +line-length = 120 +exclude = [ + "__pycache__", + ".cache", + ".mypy_cache", + ".ruff_cache", + ".venv", + "*.egg-info", + "build", + "checkpoints", + "dist", + "out", + "results", + "wandb", + "wl_data", + "*log*", +] + +[format] +indent-style = "space" +line-ending = "auto" +quote-style = "double" +skip-magic-trailing-comma = false + + +[lint] +select = ["ALL"] +fixable = ["ALL"] +ignore = [ + "ANN001", + "EM101", + "ERA001", + "FBT001", + "FBT002", + "RET504", + "TRY002", + "TRY003", +] + +[lint.extend-per-file-ignores] +"__init__.py" = ["D1", "F403"] +"build.py" = ["D1"] +"dataset.py" = ["D1"] +"eval.py" = ["D1"] +"merge.py" = ["D1"] +"train.py" = ["D1"] diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..10e47c8 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,27 @@ +import time + +from huggingface_hub import snapshot_download + +attempt = 1 +while True: + try: + print(f"Attempt {attempt} for dataset 'McGill-NLP/WebLINX-full'.") # noqa: T201 + + snapshot_download( + repo_id="McGill-NLP/WebLINX-full", + repo_type="dataset", + local_dir="./wl_data", + ignore_patterns=["**/video.mp4"], + resume_download=True, + max_workers=16, + ) + + print("Download successful!") # noqa: T201 + break + + except Exception as e: # noqa: BLE001 + print(f"Error occurred: {e}") # noqa: T201 + + print("Retrying in 5 seconds...") # noqa: T201 + time.sleep(5) + attempt 
+= 1 diff --git a/src/llama/__init__.py b/src/llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/llama/build.py b/src/llama/build.py new file mode 100644 index 0000000..d1bcd2e --- /dev/null +++ b/src/llama/build.py @@ -0,0 +1,95 @@ +import json +import logging +from functools import partial +from pathlib import Path +from typing import Any + +import huggingface_hub +import hydra +import wandb +from dotenv import load_dotenv +from omegaconf import OmegaConf +from transformers import AutoTokenizer +from weblinx import Demonstration +from weblinx.processing import load_candidate_elements +from weblinx.processing.prompt import ( + build_input_records_from_selected_turns, + select_turns_and_candidates_for_prompts, +) +from weblinx.utils import ( + load_demo_names_in_split, + set_seed, +) + +from .processing import ( + build_formatter_for_multichoice, + build_prompt_records_for_llama_truncated, + insert_formatted_chat_into_records, +) + +load_dotenv() + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info(OmegaConf.to_yaml(cfg)) + + huggingface_hub.login(token=cfg.huggingface.token) + wandb.login(key=cfg.wandb.key) + + set_seed(cfg.seed) + + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + input_records_save_dir = Path(model_save_dir.joinpath(f"{cfg.build.split}")).expanduser() + input_records_save_dir.mkdir(exist_ok=True, parents=True) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.model.base_name, + padding_side="right", + trust_remote_code=True, + ) + tokenizer.pad_token = tokenizer.unk_token + + split_path = Path(cfg.data.split_path).expanduser() + + demo_names: list[str] = load_demo_names_in_split(split_path, split=cfg.build.split) + demos = [Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] + candidates = load_candidate_elements(path=cfg.candidates.build_path) + + format_intent = build_formatter_for_multichoice() + build_prompt_records_fn = partial( + build_prompt_records_for_llama_truncated, + format_intent=format_intent, + tokenizer=tokenizer, # type: ignore # noqa: PGH003 + ) + + selected_turns: list[dict[str, Any]] = select_turns_and_candidates_for_prompts( + demos=demos, + candidates=candidates, + num_candidates=cfg.candidates.k, + ) + + input_records = build_input_records_from_selected_turns( + selected_turns=selected_turns, + format_intent=format_intent, + build_prompt_records_fn=build_prompt_records_fn, + format_prompt_records_fn=None, + ) + + input_records = insert_formatted_chat_into_records( + input_records, + demos, + tokenizer, + include_output_target=cfg.build.include_output_target, + ) + + with Path.open(input_records_save_dir.joinpath("input_records.json"), "w") as f: + json.dump(input_records, f, indent=2) + + +if __name__ == "__main__": + main() diff --git a/src/llama/conf/accelerate/deepspeed.yaml b/src/llama/conf/accelerate/deepspeed.yaml new file mode 100644 index 0000000..c4882ae --- /dev/null +++ b/src/llama/conf/accelerate/deepspeed.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: ../deepspeed/zero2.json + # deepspeed_moe_layer_cls_names: MoeLayer +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +use_cpu: false diff --git 
a/src/llama/conf/accelerate/fsdp.yaml b/src/llama/conf/accelerate/fsdp.yaml new file mode 100644 index 0000000..743503e --- /dev/null +++ b/src/llama/conf/accelerate/fsdp.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +distributed_type: FSDP +fsdp_config: + fsdp_sharding_strategy: FULL_SHARD + fsdp_offload_params: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_forward_prefetch: false + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_sync_module_states: true +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +use_cpu: false diff --git a/src/llama/conf/deepspeed/zero2.json b/src/llama/conf/deepspeed/zero2.json new file mode 100644 index 0000000..5f0ac20 --- /dev/null +++ b/src/llama/conf/deepspeed/zero2.json @@ -0,0 +1,14 @@ +{ + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "optimizer": "auto", + "scheduler": "auto", + "bf16": "auto", + "zero_optimization": { + "stage": 2, + "contiguous_gradients": true, + "offload_optimizer": { + "device": "cpu" + } + } +} diff --git a/src/llama/conf/variant/base.yaml b/src/llama/conf/variant/base.yaml new file mode 100644 index 0000000..952edc1 --- /dev/null +++ b/src/llama/conf/variant/base.yaml @@ -0,0 +1,117 @@ +project_name: webnavix +project_dir: ${oc.env:PROJECT_DIR} +seed: 123 + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +candidates: + k: 10 + build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl + train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl + eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl + +model: + name: nitic-nlp-team/webnavix-llama-base + base_name: princeton-nlp/Sheared-LLaMA-2.7B + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + moe: False + freeze: + use: False + trainable_layers: + - gate_proj + - up_proj + - down_proj + +build: + split: train + include_output_target: True + +train: + split: train + domain: False + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 16 + gradient_accumulation_steps: 1 + gradient_checkpointing: True + max_grad_norm: 1.0 + optim: adamw_torch + weight_decay: 0.0 + scheduler: linear + warmup_steps: 0 + warmup_ratio: 0.0 + accelerate: + use: False + qlora: + use: False + r: 256 + alpha: 256 + dropout: 0.05 + bias: none + target_modules: + - embed_tokens + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - up_proj + - down_proj + - lm_head + +merge: + num_experts_per_tok: 2 + experts: + - expert_name: ai-tools-expert + model_id: nitic-nlp-team/webnavix-llama-ai-tools + - expert_name: booking-expert + model_id: nitic-nlp-team/webnavix-llama-booking + - expert_name: composing-expert + model_id: nitic-nlp-team/webnavix-llama-composing + - expert_name: information-lookup-expert + model_id: nitic-nlp-team/webnavix-llama-information-lookup + - expert_name: shopping-expert + model_id: nitic-nlp-team/webnavix-llama-shopping + - expert_name: social-interaction-expert + model_id: nitic-nlp-team/webnavix-llama-social-interaction + - expert_name: summarizing-expert + model_id: nitic-nlp-team/webnavix-llama-summarizing + - expert_name: task-management-expert + 
model_id: nitic-nlp-team/webnavix-llama-task-management + - expert_name: shared-expert + model_id: nitic-nlp-team/webnavix-llama-shared + router_layers: + - gate_proj + - up_proj + - down_proj + +eval: + split: valid + domain: False + batch_size_per_device: 16 + gradient_accumulation_steps: 1 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: True + +huggingface: + token: ${oc.env:HF_TOKEN} + +wandb: + project: ${oc.env:WANDB_PROJECT} + key: ${oc.env:WANDB_API_KEY} + +hydra: + run: + dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO diff --git a/src/llama/conf/variant/expert.yaml b/src/llama/conf/variant/expert.yaml new file mode 100644 index 0000000..5b68de5 --- /dev/null +++ b/src/llama/conf/variant/expert.yaml @@ -0,0 +1,117 @@ +project_name: webnavix +project_dir: ${oc.env:PROJECT_DIR} +seed: 123 + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +candidates: + k: 10 + build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl + train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl + eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl + +model: + name: nitic-nlp-team/webnavix-ai-tools + base_name: nitic-nlp-team/webnavix-llama-base + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + moe: False + freeze: + use: True + trainable_layers: + - gate_proj + - up_proj + - down_proj + +build: + split: train + include_output_target: True + +train: + split: train + domain: AI_Tools + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 16 + gradient_accumulation_steps: 1 + gradient_checkpointing: True + max_grad_norm: 1.0 + optim: adamw_torch + weight_decay: 0.0 + scheduler: linear + warmup_steps: 0 + warmup_ratio: 0.0 + accelerate: + use: False + qlora: + use: False + r: 256 + alpha: 256 + dropout: 0.05 + bias: none + target_modules: + - embed_tokens + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - up_proj + - down_proj + - lm_head + +merge: + num_experts_per_tok: 2 + experts: + - expert_name: ai-tools-expert + model_id: nitic-nlp-team/webnavix-llama-ai-tools + - expert_name: booking-expert + model_id: nitic-nlp-team/webnavix-llama-booking + - expert_name: composing-expert + model_id: nitic-nlp-team/webnavix-llama-composing + - expert_name: information-lookup-expert + model_id: nitic-nlp-team/webnavix-llama-information-lookup + - expert_name: shopping-expert + model_id: nitic-nlp-team/webnavix-llama-shopping + - expert_name: social-interaction-expert + model_id: nitic-nlp-team/webnavix-llama-social-interaction + - expert_name: summarizing-expert + model_id: nitic-nlp-team/webnavix-llama-summarizing + - expert_name: task-management-expert + model_id: nitic-nlp-team/webnavix-llama-task-management + - expert_name: shared-expert + model_id: nitic-nlp-team/webnavix-llama-shared + router_layers: + - gate_proj + - up_proj + - down_proj + +eval: + split: valid + domain: AI_Tools + batch_size_per_device: 16 + gradient_accumulation_steps: 1 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: True + +huggingface: + token: ${oc.env:HF_TOKEN} + +wandb: + project: ${oc.env:WANDB_PROJECT} + key: ${oc.env:WANDB_API_KEY} + +hydra: + run: + dir: 
${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO diff --git a/src/llama/conf/variant/merge.yaml b/src/llama/conf/variant/merge.yaml new file mode 100644 index 0000000..23f0927 --- /dev/null +++ b/src/llama/conf/variant/merge.yaml @@ -0,0 +1,117 @@ +project_name: webnavix +project_dir: ${oc.env:PROJECT_DIR} +seed: 123 + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +candidates: + k: 10 + build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl + train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl + eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl + +model: + name: nitic-nlp-team/webnavix-llama-merged + base_name: nitic-nlp-team/webnavix-llama-base + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + moe: True + freeze: + use: False + trainable_layers: + - gate_proj + - up_proj + - down_proj + +build: + split: train + include_output_target: True + +train: + split: train + domain: False + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 16 + gradient_accumulation_steps: 1 + gradient_checkpointing: True + max_grad_norm: 1.0 + optim: adamw_torch + weight_decay: 0.0 + scheduler: linear + warmup_steps: 0 + warmup_ratio: 0.0 + accelerate: + use: False + qlora: + use: False + r: 256 + alpha: 256 + dropout: 0.05 + bias: none + target_modules: + - embed_tokens + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - up_proj + - down_proj + - lm_head + +merge: + num_experts_per_tok: 2 + experts: + - expert_name: ai-tools-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-ai-tools/checkpoint-500/ + - expert_name: booking-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-booking/checkpoint-1050/ + - expert_name: composing-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-composing/checkpoint-500/ + - expert_name: information-lookup-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-information-lookup/checkpoint-500/ + - expert_name: shopping-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-shopping/checkpoint-400/ + - expert_name: social-interaction-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-social-interaction/checkpoint-150/ + - expert_name: summarizing-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-summarizing/checkpoint-450/ + - expert_name: task-management-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-task-management/checkpoint-450/ + - expert_name: shared-expert + model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-shared/checkpoint-1800/ + router_layers: + - gate_proj + - up_proj + - down_proj + +eval: + split: valid + domain: False + batch_size_per_device: 16 + gradient_accumulation_steps: 
1 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: True + +huggingface: + token: ${oc.env:HF_TOKEN} + +wandb: + project: ${oc.env:WANDB_PROJECT} + key: ${oc.env:WANDB_API_KEY} + +hydra: + run: + dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO diff --git a/src/llama/conf/variant/moe.yaml b/src/llama/conf/variant/moe.yaml new file mode 100644 index 0000000..aab8b00 --- /dev/null +++ b/src/llama/conf/variant/moe.yaml @@ -0,0 +1,107 @@ +project_name: webnavix +project_dir: ${oc.env:PROJECT_DIR} +seed: 123 + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +candidates: + k: 10 + build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl + train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl + eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl + +model: + name: nitic-nlp-team/webnavix-llama + base_name: nitic-nlp-team/webnavix-llama-merged + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + moe: True + freeze: + use: False + trainable_layers: + - gate + +build: + split: train + include_output_target: True + +train: + split: train + domain: False + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + gradient_checkpointing: True + max_grad_norm: 1.0 + optim: adamw_torch + weight_decay: 0.0 + scheduler: linear + warmup_steps: 0 + warmup_ratio: 0.0 + accelerate: + use: True + qlora: + use: False + r: 256 + alpha: 256 + dropout: 0.05 + bias: none + target_modules: + - gate + +merge: + num_experts_per_tok: 2 + experts: + - expert_name: ai-tools-expert + model_id: nitic-nlp-team/webnavix-llama-ai-tools + - expert_name: booking-expert + model_id: nitic-nlp-team/webnavix-llama-booking + - expert_name: composing-expert + model_id: nitic-nlp-team/webnavix-llama-composing + - expert_name: information-lookup-expert + model_id: nitic-nlp-team/webnavix-llama-information-lookup + - expert_name: shopping-expert + model_id: nitic-nlp-team/webnavix-llama-shopping + - expert_name: social-interaction-expert + model_id: nitic-nlp-team/webnavix-llama-social-interaction + - expert_name: summarizing-expert + model_id: nitic-nlp-team/webnavix-llama-summarizing + - expert_name: task-management-expert + model_id: nitic-nlp-team/webnavix-llama-task-management + - expert_name: shared-expert + model_id: nitic-nlp-team/webnavix-llama-shared + router_layers: + - gate_proj + - up_proj + - down_proj + +eval: + split: valid + domain: False + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: True + +huggingface: + token: ${oc.env:HF_TOKEN} + +wandb: + project: ${oc.env:WANDB_PROJECT} + key: ${oc.env:WANDB_API_KEY} + +hydra: + run: + dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO diff --git a/src/llama/eval.py b/src/llama/eval.py new file mode 100644 index 0000000..565af51 --- /dev/null +++ b/src/llama/eval.py @@ -0,0 +1,142 @@ +import json +import logging +from pathlib import Path +from typing import Any + +import huggingface_hub +import hydra +import torch +import wandb +from accelerate import 
Accelerator
+from dotenv import load_dotenv
+from mergoo.models.modeling_llama import LlamaForCausalLM
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,  # type: ignore # noqa: PGH003
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+    pipeline,  # type: ignore # noqa: PGH003
+)
+from transformers.pipelines.pt_utils import KeyDataset
+from weblinx.utils import set_seed
+
+load_dotenv()
+
+
+@hydra.main(version_base=None, config_path="conf", config_name="config")
+def main(cfg) -> None:
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    logger.info(OmegaConf.to_yaml(cfg))
+
+    huggingface_hub.login(token=cfg.huggingface.token)
+    wandb.login(key=cfg.wandb.key)
+
+    set_seed(cfg.seed)
+
+    model_save_dir = Path(cfg.model.save_dir).expanduser()
+    model_save_dir.mkdir(exist_ok=True, parents=True)
+    result_dir = Path(cfg.eval.result_dir).expanduser()
+    result_dir.mkdir(parents=True, exist_ok=True)
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=True,
+    )
+
+    # Evaluate the fine-tuned checkpoint in save_dir when requested; otherwise fall back to the base model.
+    load_model_name = str(model_save_dir) if cfg.eval.get("load_from_save_dir", False) is True else cfg.model.base_name
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        load_model_name,
+        padding_side="left",
+        trust_remote_code=True,
+    )
+    tokenizer.pad_token = tokenizer.unk_token
+
+    accelerator = Accelerator() if cfg.train.accelerate.use else None
+    model = (
+        LlamaForCausalLM.from_pretrained(
+            load_model_name,
+            device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+            torch_dtype=torch.bfloat16,
+            rope_scaling={"type": "dynamic", "factor": 2.0} if cfg.model.use_rope else None,
+            attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+            quantization_config=bnb_config if cfg.train.qlora.use else None,
+        )
+        if cfg.model.moe
+        else AutoModelForCausalLM.from_pretrained(
+            load_model_name,
+            device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+            torch_dtype=torch.bfloat16,
+            rope_scaling={"type": "dynamic", "factor": 2.0} if cfg.model.use_rope else None,
+            attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+            quantization_config=bnb_config if cfg.train.qlora.use else None,
+        )
+    )
+
+    with Path.open(
+        model_save_dir.joinpath(
+            f"{cfg.eval.split}/{cfg.eval.domain if cfg.eval.domain else ''}/input_records.json",
+        ),
+        "r",
+    ) as f:
+        input_records = json.load(f)
+
+    evaluate(cfg, model, tokenizer, input_records, result_dir)
+
+
+def evaluate(
+    cfg,
+    model,
+    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+    input_records: list[dict[str, Any]],
+    result_dir: Path,
+) -> None:
+    key_dataset = KeyDataset(input_records, key="text")  # type: ignore # noqa: PGH003
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        torch_dtype=torch.bfloat16,
+    )
+
+    results = []
+    with torch.amp.autocast("cuda", dtype=torch.bfloat16):  # type: ignore # noqa: PGH003
+        pbar = tqdm(
+            pipe(
+                key_dataset,
+                batch_size=cfg.eval.batch_size_per_device,
+                pad_token_id=tokenizer.unk_token_id,
+                max_new_tokens=cfg.model.max_out_len,
+                return_full_text=False,
+            ),
+            desc="Generating outputs",
+            total=len(key_dataset),
+        )
+        for i, out in enumerate(pbar):
+            input_record = input_records[i]
+            generated_text = out[0]["generated_text"]
+            result = {
input_record["demo_name"], + "turn_index": input_record["turn_index"], + "prompt": input_record["prompt"], + "text": input_record["text"], + "output_predicted": generated_text, + "output_target": input_record["output_target"], + "output_target_dict": input_record["output_target_dict"], + } + + results.append(result) + + with Path.open(result_dir.joinpath("results.json"), "w") as f: + json.dump(results, f, indent=2) + + +if __name__ == "__main__": + main() diff --git a/src/llama/merge.py b/src/llama/merge.py new file mode 100644 index 0000000..000d3b9 --- /dev/null +++ b/src/llama/merge.py @@ -0,0 +1,38 @@ +import logging +from pathlib import Path + +import huggingface_hub +import hydra +import torch +from mergoo.compose_experts import ComposeExperts +from omegaconf import OmegaConf +from weblinx.utils import set_seed + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info(OmegaConf.to_yaml(cfg)) + + huggingface_hub.login(token=cfg.huggingface.token) + + set_seed(cfg.seed) + + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + + merge_config = { + "model_type": "llama", + "num_experts_per_tok": OmegaConf.to_container(cfg.merge.num_experts_per_tok, resolve=True), # type: ignore # noqa: PGH003 + "experts": OmegaConf.to_container(cfg.merge.experts, resolve=True), # type: ignore # noqa: PGH003 + "router_layers": OmegaConf.to_container(cfg.merge.router_layers, resolve=True), # type: ignore # noqa: PGH003 + } + merger = ComposeExperts(merge_config, torch_dtype=torch.bfloat16) + merger.compose() + merger.save_checkpoint(model_save_dir) + + +if __name__ == "__main__": + main() diff --git a/src/llama/processing.py b/src/llama/processing.py new file mode 100644 index 0000000..5f1986f --- /dev/null +++ b/src/llama/processing.py @@ -0,0 +1,397 @@ +"""The processing module contains functions to format the dataset and build input records.""" + +from collections.abc import Callable +from copy import deepcopy +from functools import partial + +import lxml.html +import weblinx.utils.format as wlf +from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from weblinx import ( + Demonstration, + Replay, + Turn, +) +from weblinx.processing.dom import clean_and_prune_tree +from weblinx.processing.prompt import ( + find_turns_with_instructor_chat, + format_utterances, + get_speaker, + multi_attempt_format_prev_turns_truncated, +) +from weblinx.processing.truncation import ( + multi_attempt_truncate_cands_turn, + multi_attempt_truncate_dom_tree, + reduce_list_of_lengths, + truncate_text_at_center, +) +from weblinx.utils.recs import get_list_from_records_by_key + + +def build_formatter_for_multichoice() -> Callable: # noqa: D103 + format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) + format_text_input = partial( + wlf.format_text_input, + formatters=( + partial(wlf.format_arg_item, name="text", max_length=200), + wlf.format_uid, + ), + ) + format_change = partial( + wlf.format_change, + formatters=( + partial(wlf.format_arg_item, name="value", max_length=200), + wlf.format_uid, + ), + ) + format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) + format_load = partial( + wlf.format_load, + include_transition=False, + include_timestamp=False, + max_length=200, + ) + format_scroll = partial(wlf.format_scroll, include_timestamp=False) + + format_say = 
+ merger = ComposeExperts(merge_config, torch_dtype=torch.bfloat16) + merger.compose() + merger.save_checkpoint(model_save_dir) + + +if __name__ == "__main__": + main() diff --git a/src/llama/processing.py b/src/llama/processing.py new file mode 100644 index 0000000..5f1986f --- /dev/null +++ b/src/llama/processing.py @@ -0,0 +1,397 @@ +"""The processing module contains functions to format the dataset and build input records.""" + +from collections.abc import Callable +from copy import deepcopy +from functools import partial + +import lxml.html +import weblinx.utils.format as wlf +from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from weblinx import ( + Demonstration, + Replay, + Turn, +) +from weblinx.processing.dom import clean_and_prune_tree +from weblinx.processing.prompt import ( + find_turns_with_instructor_chat, + format_utterances, + get_speaker, + multi_attempt_format_prev_turns_truncated, +) +from weblinx.processing.truncation import ( + multi_attempt_truncate_cands_turn, + multi_attempt_truncate_dom_tree, + reduce_list_of_lengths, + truncate_text_at_center, +) +from weblinx.utils.recs import get_list_from_records_by_key + + +def build_formatter_for_multichoice() -> Callable: # noqa: D103 + format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) + format_text_input = partial( + wlf.format_text_input, + formatters=( + partial(wlf.format_arg_item, name="text", max_length=200), + wlf.format_uid, + ), + ) + format_change = partial( + wlf.format_change, + formatters=( + partial(wlf.format_arg_item, name="value", max_length=200), + wlf.format_uid, + ), + ) + format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) + format_load = partial( + wlf.format_load, + include_transition=False, + include_timestamp=False, + max_length=200, + ) + format_scroll = partial(wlf.format_scroll, include_timestamp=False) + + format_say = partial(wlf.format_say, include_timestamp=False) + + format_intent_auto = partial( + wlf.format_intent_automatically, + format_change=format_change, + format_click=format_click, + format_load=format_load, + format_say=format_say, + format_scroll=format_scroll, + format_submit=format_submit, + format_text_input=format_text_input, + ) + + return format_intent_auto + + +def __get_system_prompt_template_for_llama_mc_concise() -> str: + sys_prompt_template = ( + "You are an AI assistant with a deep understanding of HTML " + "and you must predict actions based on a user request, which will be executed. " + "Use one of the following, replacing [] with an appropriate value: " + "change(value=[str], uid=[str]) ; " + "click(uid=[str]) ; " + "load(url=[str]) ; " + 'say(speaker="navigator", utterance=[str]) ; ' + "scroll(x=[int], y=[int]) ; " + "submit(uid=[str]) ; " + "text_input(text=[str], uid=[str]) ;\n" + "The user's first and last {num_utterances} utterances are: " + "{utterance_context} ;\n" + "Viewport size: {height}h x {width}w ;\n" + "Only the last {num_prev_turns} turns are provided." + ) + + return sys_prompt_template + + +def __get_candidate_prompt_template_for_llama() -> str: + return "Here are the top candidates for this turn:\n{candidate_str}" + + +def __get_final_user_message() -> str: + return ( + "Please select the best action using the correct format, " + "do not provide any other information or explanation." + ) + + +def __merge_prev_turns(prev_turns_text_list: list[str], final_user_message: str) -> list[dict[str, str]]: + prev_turns_merged: list[dict[str, str]] = [] + + for i, turn_text in enumerate(prev_turns_text_list): + role = get_speaker( + turn_text, + instructor_name="user", + navigator_name="assistant", + default_name="unknown", + ) + + if i > 0 and prev_turns_merged[-1]["role"] == role: + prev_turns_merged[-1]["content"] += " " + turn_text + else: + prev_turns_merged.append({"role": role, "content": turn_text}) + + if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user": + prev_turns_merged[-1]["content"] += " " + final_user_message + else: + prev_turns_merged.append({"role": "user", "content": final_user_message}) + + return prev_turns_merged + + +def __format_utterances_truncated( # noqa: ANN202, PLR0913 + turns: list["Turn"], + tokenizer: "PreTrainedTokenizer", + max_tokens: int, + format_utterances_fn, + num_utterances: int = 5, + type_filter="chat", + sep=" ", + convert_to_minutes=True, + template="[{timestamp}] {utterance}", + allow_iterative_reduction=False, +): + utterances = format_utterances_fn( + turns, + num_utterances=num_utterances, + type_filter=type_filter, + sep=None, + convert_to_minutes=convert_to_minutes, + template=template, + ) + if isinstance(utterances, str): + utterances = [utterances] + + utterances_str = " ".join(utterances) if sep is None else str(sep).join(utterances) + utter_tokens = tokenizer.tokenize(utterances_str, add_special_tokens=False) + num_tokens_to_remove = len(utter_tokens) - max_tokens + + records = [] + for i, text in enumerate(utterances): + tokens = tokenizer.tokenize(text, add_special_tokens=False) + records.append( + { + "index": i, + "text": text, + "tokens": tokens, + "length": len(tokens), + }, + ) + + # NOTE: We only count the token lengths of the values, not the entire formatted string. + # The full string may have additional tokens. (key, separator, etc.) + # Consequently, max_total_length is different from max_tokens.
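+ # NOTE: Rough shape of the redistribution below, assuming `reduce_list_of_lengths` trims the longest entries first: the shared budget (`max_total_length`) is spread over the sorted records, and only an utterance whose reduced budget falls below its current length is actually shortened by `truncate_text_at_center`. For example, with lengths [5, 20, 40] and 15 tokens to remove, the 5- and 20-token utterances should stay intact while the 40-token one is cut down to 25.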
+ records = sorted(records, key=lambda r: r["length"]) + lengths_orig = get_list_from_records_by_key(records, "length") # type: ignore # noqa: PGH003 + max_total_length = sum(lengths_orig) - num_tokens_to_remove + lengths_reduced = reduce_list_of_lengths(lengths_orig, max_length=max_total_length) + + for i, rec in enumerate(records): + red_length = lengths_reduced[i] + + # NOTE: If the length is the same, then we don't need to do anything. + # Otherwise, we need to truncate the text. + if red_length >= rec["length"]: + continue + + trunc = truncate_text_at_center( + rec["text"], + tokenizer=tokenizer, + max_tokens=red_length, + allow_iterative_reduction=allow_iterative_reduction, + ) + + utterances[rec["index"]] = trunc["text"] + + if sep is None: + return utterances + + return sep.join(utterances) + + +def __format_candidates(candidates, max_char_len=300, use_uid_as_rank=False): # noqa: ANN202 + s = "" + for cand in candidates: + doc = cand["doc"].replace("\n", " ").rstrip() + rank = "uid = " + cand["uid"] if use_uid_as_rank else cand["rank"] + + if max_char_len is not None and len(doc) > max_char_len: + doc = doc[: max_char_len - 3] + "..." + + s += f"({rank}) {doc}\n" + + return s + + +def build_prompt_records_for_llama_truncated( # noqa: D103, PLR0913 + replay: Replay, + turn: Turn, + format_intent, + tokenizer: PreTrainedTokenizer, + cands_turn=None, + num_utterances: int = 5, + num_prev_turns: int = 5, + system_prompt_template=None, + candidate_prompt_template=None, + final_user_message=None, + include_html=True, + format_candidates_fn=partial( # noqa: B008 + __format_candidates, + max_char_len=None, # type: ignore # noqa: PGH003 + use_uid_as_rank=True, + ), + merge_prev_turns_fn=__merge_prev_turns, + format_output_dict_fn: Callable = partial( # noqa: B008 + wlf.format_output_dictionary, + function_key="intent", + ), + max_html_tokens: int = 700, + max_utterance_tokens: int = 40 * 5, + max_prev_turns_tokens: int = 50 * 5, + max_candidates_tokens: int = 65 * 10, + add_unused_len_to_cands: bool = True, + allow_iterative_reduction: bool = False, + use_tokenizer_template: bool = False, + template_tokenizer=None, + parser=None, +) -> list[dict[str, str]]: + if system_prompt_template is None: + system_prompt_template = __get_system_prompt_template_for_llama_mc_concise() + + if candidate_prompt_template is None: + candidate_prompt_template = __get_candidate_prompt_template_for_llama() + + if final_user_message is None: + final_user_message = __get_final_user_message() + + instructor_chat_turns = find_turns_with_instructor_chat( + replay, + turn, + num_prev_turns=num_prev_turns, + ) + utterance_context = __format_utterances_truncated( + instructor_chat_turns, + tokenizer=tokenizer, + max_tokens=max_utterance_tokens, + num_utterances=num_utterances, + format_utterances_fn=format_utterances, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_text_list = multi_attempt_format_prev_turns_truncated( + replay=replay, + turn=turn, + format_intent=partial(format_intent, return_as=dict), + tokenizer=tokenizer, + num_prev_turns=num_prev_turns, + turn_sep=None, # type: ignore # noqa: PGH003 + max_tokens=max_prev_turns_tokens, + max_attempts=5, + format_output_dict_fn=format_output_dict_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_merged = merge_prev_turns_fn( + prev_turns_text_list=prev_turns_text_list, + final_user_message=final_user_message, + ) + + sys_prompt = system_prompt_template.format( + 
num_utterances=num_utterances - 1, # NOTE: 1 less since we add the first utterance. + utterance_context=utterance_context, + height=turn.viewport_height, + width=turn.viewport_width, + num_prev_turns=num_prev_turns, + ) + + if include_html and turn.html not in ["", None] and cands_turn is not None: + dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser) + dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn) + trunc = multi_attempt_truncate_dom_tree( + dom_tree=dom_tree_pruned, + tokenizer=tokenizer, + max_tokens=max_html_tokens, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + html = trunc["tree_repr"] + sys_prompt = html + "\n" + sys_prompt + else: + html = "" + + if cands_turn is not None: + if add_unused_len_to_cands: + # NOTE: Add the unused length to the candidates. + num_html_tokens = len(tokenizer.tokenize(html)) + num_utter_tokens = len(tokenizer.tokenize(utterance_context)) # type: ignore # noqa: PGH003 + if use_tokenizer_template: + if template_tokenizer is None: + msg = "template_tokenizer must be provided when use_tokenizer_template is True." + raise ValueError(msg) + prev_turns_merged_copy = deepcopy(prev_turns_merged) + if prev_turns_merged[0]["role"] == "assistant": + prev_turns_merged_copy.insert(0, {"role": "user", "content": ""}) + num_prev_turns_tokens = len( + template_tokenizer.apply_chat_template( + [{"role": "system", "content": ""}, *prev_turns_merged_copy], + tokenize=True, + ), + ) + else: + num_prev_turns_tokens = len( + tokenizer.tokenize(" ".join(prev_turns_text_list)), + ) + remain_html_tokens = max_html_tokens - num_html_tokens + remain_utter_tokens = max_utterance_tokens - num_utter_tokens + remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens + remain_tokens = remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens + # NOTE: Add the unused length to the max_candidates_tokens. + max_candidates_tokens += remain_tokens + + cands_turn_trunc = multi_attempt_truncate_cands_turn( + cands_turn=cands_turn, + tokenizer=tokenizer, + max_tokens=max_candidates_tokens, + format_candidates_fn=format_candidates_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None) # type: ignore # noqa: PGH003 + cand_prompt = candidate_prompt_template.format(candidate_str=cand_str) + sys_prompt += cand_prompt[:-1] + + return [{"role": "system", "content": sys_prompt}, *prev_turns_merged] + + +def __insert_empty_user_content_at_first(prompt: list) -> None: + if prompt[0]["role"] != "system": + msg = f"First prompt must be a system prompt. Got {prompt[0]['role']} instead." + raise ValueError(msg) + + if prompt[1]["role"] != "user": + prompt.insert(1, {"role": "user", "content": ""}) + + +def insert_formatted_chat_into_records( # noqa: D103 + records, + demos: list[Demonstration], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + *, + include_output_target: bool = True, +) -> list: + processed_records = deepcopy(records) + for i, record in enumerate(records): + __insert_empty_user_content_at_first(record["prompt"]) + + if include_output_target: + target = [{"role": "assistant", "content": record["output_target"]}] + combined = record["prompt"] + target + else: + combined = record["prompt"] + + # NOTE: The `apply_chat_template` method of the tokenizer is required. 
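+        # NOTE: With `tokenize=False`, `apply_chat_template` only renders the conversation into a single training string; for a Llama-2-style chat template the result looks roughly like (illustrative, the exact markup depends on the checkpoint's template): +        # "<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant} </s>" +        # so the assistant target appended above lands in the same `text` field that SFTTrainer later consumes via `dataset_text_field="text"`.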
+ text = str( + tokenizer.apply_chat_template( + combined, + tokenize=False, + add_generation_prompt=False, + ), + ) + + processed_records[i]["text"] = text + + processed_records[i]["tasks"] = next( + filter(lambda demo: demo.form["shortcode"] == record["demo_name"], demos), + ).form["tasks"] + + return processed_records diff --git a/src/llama/train.py b/src/llama/train.py new file mode 100644 index 0000000..064254e --- /dev/null +++ b/src/llama/train.py @@ -0,0 +1,185 @@ +import json +import logging +from pathlib import Path +from typing import Any + +import huggingface_hub +import hydra +import torch +import wandb +from accelerate import Accelerator +from datasets import Dataset +from dotenv import load_dotenv +from mergoo.models.modeling_llama import LlamaForCausalLM +from omegaconf import OmegaConf +from peft import LoraConfig # type: ignore # noqa: PGH003 +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, # type: ignore # noqa: PGH003 + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TrainingArguments, +) +from trl import SFTTrainer +from weblinx.utils import set_seed + +load_dotenv() + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info(OmegaConf.to_yaml(cfg)) + + huggingface_hub.login(token=cfg.huggingface.token) + wandb.login(key=cfg.wandb.key) + + set_seed(cfg.seed) + + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.model.base_name, + padding_side="right", + trust_remote_code=True, + ) + tokenizer.pad_token = tokenizer.unk_token + + accelerator = Accelerator() if cfg.train.accelerate.use else None + model = ( + LlamaForCausalLM.from_pretrained( + cfg.model.base_name, + device_map={"": accelerator.process_index} if accelerator is not None else "auto", + torch_dtype=torch.bfloat16, + use_cache=False, + attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None, + quantization_config=bnb_config if cfg.train.qlora.use else None, + ) + if cfg.model.moe + else AutoModelForCausalLM.from_pretrained( + cfg.model.base_name, + device_map={"": accelerator.process_index} if accelerator is not None else "auto", + torch_dtype=torch.bfloat16, + use_cache=False, + attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None, + quantization_config=bnb_config if cfg.train.qlora.use else None, + ) + ) + if cfg.model.freeze.use: + for name, param in model.named_parameters(): # type: ignore # noqa: PGH003 + if any(layer in name for layer in cfg.model.freeze.trainable_layers): + param.requires_grad = True + else: + param.requires_grad = False + + with Path.open( + model_save_dir.joinpath( + f"{cfg.train.split}/{cfg.train.domain if cfg.train.domain else ''}/input_records.json", + ), + "r", + ) as f: + train_input_records = json.load(f) + with Path.open( + model_save_dir.joinpath( + f"{cfg.eval.split}/{cfg.eval.domain if cfg.eval.domain else ''}/input_records.json", + ), + "r", + ) as f: + eval_input_records = json.load(f) + + train_input_texts = [{"text": record["text"]} for record in train_input_records] + eval_input_texts = [{"text": record["text"]} for record in eval_input_records] + + train( + 
cfg, + model, + tokenizer, + train_input_texts, + eval_input_texts, + model_save_dir, + ) + + +def train( # noqa: PLR0913 + cfg, + model, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + train_input_texts: list[dict[str, Any]], + eval_input_texts: list[dict[str, Any]], + model_save_dir: Path, +) -> None: + peft_config = LoraConfig( + r=cfg.train.qlora.r, + lora_alpha=cfg.train.qlora.alpha, + lora_dropout=cfg.train.qlora.dropout, + bias=cfg.train.qlora.bias, + task_type="CAUSAL_LM", + target_modules=OmegaConf.to_container(cfg.train.qlora.target_modules, resolve=True), # type: ignore # noqa: PGH003 + ) + + training_args = TrainingArguments( + output_dir=str(model_save_dir), + num_train_epochs=cfg.train.num_epochs, + learning_rate=cfg.train.learning_rate, + per_device_train_batch_size=cfg.train.batch_size_per_device, + gradient_accumulation_steps=cfg.train.gradient_accumulation_steps, + gradient_checkpointing=cfg.train.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + max_grad_norm=cfg.train.max_grad_norm, + optim=cfg.train.optim, + weight_decay=cfg.train.weight_decay, + lr_scheduler_type=cfg.train.scheduler, + warmup_steps=cfg.train.warmup_steps, + warmup_ratio=cfg.train.warmup_ratio, + save_strategy="steps", + save_steps=100, + eval_strategy="steps", + per_device_eval_batch_size=cfg.eval.batch_size_per_device, + eval_accumulation_steps=cfg.eval.gradient_accumulation_steps, + eval_steps=100, + logging_strategy="steps", + logging_steps=10, + logging_first_step=True, + bf16=True, + bf16_full_eval=True, + group_by_length=True, + prediction_loss_only=True, + metric_for_best_model="eval_loss", + load_best_model_at_end=True, + torch_compile=True, + report_to="wandb", + ) # type: ignore # noqa: PGH003 + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + args=training_args, # type: ignore # noqa: PGH003 + train_dataset=Dataset.from_list(train_input_texts), + eval_dataset=Dataset.from_list(eval_input_texts), + max_seq_length=model.config.max_position_embeddings, + dataset_text_field="text", + peft_config=peft_config if cfg.train.qlora.use else None, + ) + + wandb.init(project=cfg.wandb.project) + + trainer.train() # type: ignore # noqa: PGH003 + + trainer.save_model(str(model_save_dir)) + tokenizer.save_pretrained(model_save_dir) + trainer.state.save_to_json(str(model_save_dir.joinpath("trainer_state.json"))) + + +if __name__ == "__main__": + main() diff --git a/webnavix.code-workspace b/webnavix.code-workspace new file mode 100644 index 0000000..a26ff12 --- /dev/null +++ b/webnavix.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "name": "webnavix", + "path": "." + } + ] +}