From f7d4a90300b53db75a0b3e3b751a96641a970574 Mon Sep 17 00:00:00 2001
From: shio <85730998+dino3616@users.noreply.github.com>
Date: Wed, 13 Nov 2024 07:41:30 +0900
Subject: [PATCH] =?UTF-8?q?initial:=20=F0=9F=8E=89=20first=20commit?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.devcontainer/devcontainer.json | 20 ++
.dockerignore | 30 ++
.editorconfig | 12 +
.env.example | 5 +
.github/workflows/app-test.yaml | 19 ++
.gitignore | 33 ++
.python-version | 1 +
.vscode/settings.json | 187 +++++++++++
LICENSE | 201 ++++++++++++
README.md | 100 ++++++
docker/Dockerfile.development | 26 ++
docker/docker-compose.development.yaml | 13 +
lefthook.yaml | 9 +
pyproject.toml | 37 +++
requirements-dev.lock | 264 +++++++++++++++
requirements.lock | 262 +++++++++++++++
ruff.toml | 48 +++
src/__init__.py | 0
src/dataset.py | 27 ++
src/llama/__init__.py | 0
src/llama/build.py | 95 ++++++
src/llama/conf/accelerate/deepspeed.yaml | 13 +
src/llama/conf/accelerate/fsdp.yaml | 21 ++
src/llama/conf/deepspeed/zero2.json | 14 +
src/llama/conf/variant/base.yaml | 117 +++++++
src/llama/conf/variant/expert.yaml | 117 +++++++
src/llama/conf/variant/merge.yaml | 117 +++++++
src/llama/conf/variant/moe.yaml | 107 ++++++
src/llama/eval.py | 142 ++++++++
src/llama/merge.py | 38 +++
src/llama/processing.py | 397 +++++++++++++++++++++++
src/llama/train.py | 185 +++++++++++
webnavix.code-workspace | 8 +
33 files changed, 2665 insertions(+)
create mode 100644 .devcontainer/devcontainer.json
create mode 100644 .dockerignore
create mode 100644 .editorconfig
create mode 100644 .env.example
create mode 100644 .github/workflows/app-test.yaml
create mode 100644 .gitignore
create mode 100644 .python-version
create mode 100644 .vscode/settings.json
create mode 100644 LICENSE
create mode 100644 README.md
create mode 100644 docker/Dockerfile.development
create mode 100644 docker/docker-compose.development.yaml
create mode 100644 lefthook.yaml
create mode 100644 pyproject.toml
create mode 100644 requirements-dev.lock
create mode 100644 requirements.lock
create mode 100644 ruff.toml
create mode 100644 src/__init__.py
create mode 100644 src/dataset.py
create mode 100644 src/llama/__init__.py
create mode 100644 src/llama/build.py
create mode 100644 src/llama/conf/accelerate/deepspeed.yaml
create mode 100644 src/llama/conf/accelerate/fsdp.yaml
create mode 100644 src/llama/conf/deepspeed/zero2.json
create mode 100644 src/llama/conf/variant/base.yaml
create mode 100644 src/llama/conf/variant/expert.yaml
create mode 100644 src/llama/conf/variant/merge.yaml
create mode 100644 src/llama/conf/variant/moe.yaml
create mode 100644 src/llama/eval.py
create mode 100644 src/llama/merge.py
create mode 100644 src/llama/processing.py
create mode 100644 src/llama/train.py
create mode 100644 webnavix.code-workspace
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000..27baf19
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,20 @@
+{
+ "name": "webnavix",
+ "workspaceFolder": "/workspaces/webnavix/",
+ "dockerComposeFile": ["../docker/docker-compose.development.yaml"],
+ "service": "app",
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "adam-bender.commit-message-editor",
+ "charliermarsh.ruff",
+ "eamodio.gitlens",
+ "EditorConfig.EditorConfig",
+ "esbenp.prettier-vscode",
+ "ms-python.python",
+ "tamasfe.even-better-toml",
+ "VisualStudioExptTeam.vscodeintellicode"
+ ]
+ }
+ }
+}
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..3f93c3b
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,30 @@
+# cache
+**/__pycache__/
+**/.cache/
+**/.mypy_cache/
+**/.ruff_cache/
+**/*.egg-info/
+
+# dataset
+**/wl_data/
+
+# debug
+**/*log*
+**/wandb/
+
+# deliverable
+**/build/
+**/dist/
+**/out/
+**/results/
+
+# dependency
+**/.venv/
+
+# env file
+**/.env*
+!**/.env.example
+
+# misc
+**/.DS_Store
+**/*.pem
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..4a7ea30
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,12 @@
+root = true
+
+[*]
+indent_style = space
+indent_size = 2
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.md]
+trim_trailing_whitespace = false
diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..cfe222a
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,5 @@
+COMPOSE_PROJECT_NAME="webnavix"
+HF_TOKEN=""
+PROJECT_DIR="/workspaces/webnavix/"
+WANDB_API_KEY=""
+WANDB_PROJECT=""
diff --git a/.github/workflows/app-test.yaml b/.github/workflows/app-test.yaml
new file mode 100644
index 0000000..96dc403
--- /dev/null
+++ b/.github/workflows/app-test.yaml
@@ -0,0 +1,19 @@
+name: app test
+
+on: push
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - name: checkout
+ uses: actions/checkout@v4
+
+ - name: setup rye
+ uses: eifinger/setup-rye@v4
+
+ - name: install dependencies
+ run: rye sync
+
+ - name: check
+ run: rye check
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6c3c3c9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,33 @@
+# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
+
+# cache
+__pycache__/
+.cache/
+.mypy_cache/
+.ruff_cache/
+*.egg-info/
+
+# dataset
+wl_data
+
+# debug
+*log*
+wandb/
+
+# deliverable
+build/
+dist/
+out/
+results/
+checkpoints
+
+# dependency
+.venv/
+
+# env file
+.env*
+!.env.example
+
+# misc
+.DS_Store
+*.pem
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..e4fba21
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.12
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..f7e8142
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,187 @@
+{
+ "commit-message-editor.tokens": [
+ {
+ "label": "Type",
+ "name": "type",
+ "type": "enum",
+ "description": "Type of changes.",
+ "combobox": true,
+ "options": [
+ {
+ "label": "feat: โจ",
+ "value": "feat: โจ",
+ "description": "Implementation of new features."
+ },
+ {
+ "label": "feat: ๐",
+ "value": "feat: ๐",
+ "description": "Repair of existing features."
+ },
+ {
+ "label": "feat: โฐ๏ธ",
+ "value": "feat: โฐ๏ธ",
+ "description": "Deletion of features."
+ },
+ {
+ "label": "fix: ๐",
+ "value": "fix: ๐",
+ "description": "Bug fixes."
+ },
+ {
+ "label": "fix: ๐๏ธ",
+ "value": "fix: ๐๏ธ",
+ "description": "Critical bug fixes or major changes."
+ },
+ {
+ "label": "doc: ๐",
+ "value": "doc: ๐",
+ "description": "Documentation changes."
+ },
+ {
+ "label": "typo: ๐๏ธ",
+ "value": "typo: ๐๏ธ",
+ "description": "Typography changes."
+ },
+ {
+ "label": "style: ๐",
+ "value": "style: ๐",
+ "description": "Style changes."
+ },
+ {
+ "label": "refactor: โป๏ธ",
+ "value": "refactor: โป๏ธ",
+ "description": "Code formatting or refactoring."
+ },
+ {
+ "label": "test: ๐งช",
+ "value": "test: ๐งช",
+ "description": "Test cases changes."
+ },
+ {
+ "label": "ci: ๐ฆบ",
+ "value": "ci: ๐ฆบ",
+ "description": "CI changes."
+ },
+ {
+ "label": "build: ๐ฆ๏ธ",
+ "value": "build: ๐ฆ๏ธ",
+ "description": "Build system or dependency changes."
+ },
+ {
+ "label": "container: ๐ณ",
+ "value": "container: ๐ณ",
+ "description": "The Dockerfile changes."
+ },
+ {
+ "label": "container: ๐",
+ "value": "container: ๐",
+ "description": "The docker-compose changes."
+ },
+ {
+ "label": "chore: ๐ง",
+ "value": "chore: ๐ง",
+ "description": "Configuration changes."
+ },
+ {
+ "label": "chore: ๐จ",
+ "value": "chore: ๐จ",
+ "description": "Development script changes."
+ },
+ {
+ "label": "chore: ๐ฑ",
+ "value": "chore: ๐ฑ",
+ "description": "Assets changes."
+ },
+ {
+ "label": "revert: โช๏ธ",
+ "value": "revert: โช๏ธ",
+ "description": "Reversion of changes."
+ },
+ {
+ "label": "wip: ๐ง",
+ "value": "wip: ๐ง",
+ "description": "Changes that will be squashed."
+ },
+ {
+ "label": "initial: ๐",
+ "value": "initial: ๐",
+ "description": "The first commit."
+ }
+ ]
+ },
+ {
+ "label": "Scope",
+ "name": "scope",
+ "type": "text",
+ "description": "Scope of changes.",
+ "prefix": " (",
+ "suffix": ")"
+ },
+ {
+ "label": "Short Description",
+ "name": "description",
+ "type": "text",
+ "description": "Commit summary.",
+ "prefix": " "
+ },
+ {
+ "label": "Body",
+ "name": "body",
+ "type": "text",
+ "description": "Detailed description of commit.",
+ "maxLines": 10,
+ "multiline": true,
+ "lines": 5
+ },
+ {
+ "label": "Footer",
+ "name": "footer",
+ "description": "Description of disruptive changes or signature.",
+ "type": "text",
+ "multiline": true
+ }
+ ],
+ "commit-message-editor.dynamicTemplate": [
+ "{type}{scope}{description}",
+ "",
+ "{body}",
+ "",
+ "{footer}"
+ ],
+ "commit-message-editor.staticTemplate": [
+ "label: emoji (scope) short-description",
+ "",
+ "body",
+ "",
+ "footer"
+ ],
+ "commit-message-editor.view.defaultView": "form",
+ "editor.defaultFormatter": "esbenp.prettier-vscode",
+ "files.encoding": "utf8",
+ "files.eol": "\n",
+ "python.analysis.exclude": [
+ "**/__pycache__/",
+ "**/.cache/",
+ "**/.mypy_cache/",
+ "**/.ruff_cache/",
+ "**/.venv/",
+ "**/*.egg-info/",
+ "**/build/",
+ "**/checkpoints/",
+ "**/dist/",
+ "**/out/",
+ "**/results/",
+ "**/wandb/",
+ "**/wl_data/",
+ "**/*log/*"
+ ],
+ "python.analysis.typeCheckingMode": "basic",
+ "python.defaultInterpreterPath": "/opt/rye/shims/python",
+ "[python]": {
+ "editor.codeActionsOnSave": {
+ "source.fixAll.ruff": "explicit",
+ "source.organizeImports.ruff": "explicit"
+ },
+ "editor.defaultFormatter": "charliermarsh.ruff"
+ }
+}
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8edc12a
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 NITIC-NLP-Team
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3ef83e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,100 @@
+# WebNavix: Continuous Generalist Web Navigation Agent using Domain-wise Mixture-of-Experts
+
+WebNavix is a continuous generalist web navigation agent that merges individually fine-tuned LLMs as domain experts.
+
+## Core Contributors ๐ ๏ธ
+
+| shio | ituki |
+| :--------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------: |
+| [](https://github.com/dino3616) | [](https://github.com/ituki0426) |
+| `#repository-owner` `#main-author` `#model-composer` | `#co-author` `#model-analyst` |
+
+## Setup with Dev Containers ๐ฆ
+
+You can easily launch the development environment of WebNavix with Dev Containers.
+Here is the step-by-step guide.
+
+### Attention
+
+- You need to install [Docker](https://docs.docker.com/get-docker) and [VSCode](https://code.visualstudio.com) before.
+
+### 1. clone git repository
+
+```bash
+git clone "https://github.com/nitic-nlp-team/webnavix" && cd "./webnavix"
+```
+
+### 2. set environment variables
+
+See `.env.example` or contact the [repository owner](https://github.com/dino3616) for more details.
+
+### 3. launch dev containers
+
+Launch containers using the VSCode extension [Dev Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers).
+
+### 4. pin python version
+
+```bash
+rye pin $(cat "./.python-version")
+```
+
+### 5. install dependencies
+
+```bash
+rye sync
+```
+
+### 6. activate virtual environment
+
+```bash
+source "./.venv/bin/activate"
+```
+
+### 7. install FlashAttention-2
+
+```bash
+uv pip install flash-attn --no-build-isolation
+```
+
+## Setup locally ๐ฅ๏ธ
+
+If you want to build an environment more quickly without Docker, you can follow these steps to build your environment locally.
+
+### Attention
+
+- You need to install [rye](https://rye.astral.sh/guide/installation) before.
+- [Optional] You should install project recommended VSCode extensions that specified in [`.devcontainer/devcontainer.json`](./.devcontainer/devcontainer.json#L8C7-L17C8) before.
+
+### 1. clone git repository
+
+```bash
+git clone "https://github.com/nitic-nlp-team/webnavix" && cd "./webnavix"
+```
+
+### 2. set environment variables
+
+See `.env.example` or contact the [repository owner](https://github.com/dino3616) for more details.
+
+### 3. pin python version
+
+```bash
+rye pin $(cat "./.python-version")
+```
+
+### 4. install dependencies
+
+```bash
+rye sync
+```
+
+### 5. activate virtual environment
+
+```bash
+source "./.venv/bin/activate"
+```
+
+### 6. install FlashAttention-2
+
+```bash
+uv pip install flash-attn --no-build-isolation
+```
diff --git a/docker/Dockerfile.development b/docker/Dockerfile.development
new file mode 100644
index 0000000..4c54fe1
--- /dev/null
+++ b/docker/Dockerfile.development
@@ -0,0 +1,26 @@
+FROM python:3.12
+
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
+
+# hadolint ignore=DL3008
+RUN apt-get update \
+ && apt-get --no-install-recommends -y install git gnupg2 ca-certificates curl pipx \
+ && pipx ensurepath \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists
+
+RUN curl -sSf https://rye.astral.sh/get | RYE_INSTALL_OPTION="--yes" bash \
+ && echo "source '$HOME/.rye/env'" >> ~/.bashrc \
+ && /root/.rye/shims/rye config --set-bool behavior.global-python=true \
+ && /root/.rye/shims/rye config --set-bool behavior.use-uv=true
+
+RUN RYE_UV_HOME=$(find "$HOME/.rye/uv" -type d -regex '.*/[0-9]+\.[0-9]+\.[0-9]+$') \
+ && echo "export PATH=\"$RYE_UV_HOME:\$PATH\"" >> ~/.bashrc
+
+WORKDIR /workspaces/webnavix/
+
+COPY ./.python-version ./pyproject.toml ./requirements* ./
+# hadolint ignore=SC1091
+RUN "$HOME/.rye/shims/rye" pin "$(cat ./.python-version)" && "$HOME/.rye/shims/rye" sync && source ./.venv/bin/activate
+
+COPY ./ ./
diff --git a/docker/docker-compose.development.yaml b/docker/docker-compose.development.yaml
new file mode 100644
index 0000000..4d2eb25
--- /dev/null
+++ b/docker/docker-compose.development.yaml
@@ -0,0 +1,13 @@
+services:
+ app:
+ container_name: webnavix
+ build:
+ context: ../
+ dockerfile: ./docker/Dockerfile.development
+ volumes:
+ - type: bind
+ source: ../
+ target: /workspaces/webnavix/
+ environment:
+ PROJECT_DIR: /workspaces/webnavix/
+ tty: true
diff --git a/lefthook.yaml b/lefthook.yaml
new file mode 100644
index 0000000..9de67fa
--- /dev/null
+++ b/lefthook.yaml
@@ -0,0 +1,9 @@
+pre-commit:
+ parallel: true
+ commands:
+ check-py:
+ glob: "*.*{py}*"
+ run: ruff check --fix {staged_files}
+ format-py:
+ glob: "*.*{py}*"
+ run: ruff format --fix {staged_files}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..956d8a0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,37 @@
+[project]
+name = "webnavix"
+version = "0.1.0"
+description = "Continuous Generalist Web Navigation Agent using domain-wise Mixture-of-Experts that merges individually fine-tuned LLMs as domain experts."
+authors = [
+ { name = "shio", email = "85730998+dino3616@users.noreply.github.com" },
+]
+dependencies = [
+ "accelerate==0.27.2",
+ "bitsandbytes==0.42.0",
+ "datasets==2.19.2",
+ "deepspeed==0.15.1",
+ "huggingface-hub==0.24.6",
+ "hydra-core==1.3.2",
+ "lxml==5.3.0",
+ "mergoo==0.0.10",
+ "omegaconf==2.3.0",
+ "peft==0.12.0",
+ "python-dotenv==1.0.1",
+ "torch==2.4.1",
+ "tqdm==4.66.2",
+ "transformers==4.42.4",
+ "trl==0.10.1",
+ "wandb==0.17.9",
+ "weblinx[eval]==0.3.0",
+]
+readme = "README.md"
+requires-python = "~=3.12"
+
+[tool.rye]
+managed = true
+dev-dependencies = ["lefthook==0.1.2", "ruff==0.6.4"]
+
+[tool.rye.scripts]
+check = { chain = ["lint", "fmt"] }
+lint = "ruff check ./ --diff"
+fmt = "ruff fmt ./"
diff --git a/requirements-dev.lock b/requirements-dev.lock
new file mode 100644
index 0000000..3b552f2
--- /dev/null
+++ b/requirements-dev.lock
@@ -0,0 +1,264 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: []
+# all-features: false
+# with-sources: false
+# generate-hashes: false
+# universal: false
+
+-e file:.
+accelerate==0.27.2
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+aiohttp==3.9.5
+ # via datasets
+ # via fsspec
+aiosignal==1.3.1
+ # via aiohttp
+annotated-types==0.7.0
+ # via pydantic
+antlr4-python3-runtime==4.9.3
+ # via hydra-core
+ # via omegaconf
+attrs==23.2.0
+ # via aiohttp
+bitsandbytes==0.42.0
+ # via webnavix
+certifi==2024.7.4
+ # via requests
+ # via sentry-sdk
+charset-normalizer==3.3.2
+ # via requests
+click==8.1.7
+ # via wandb
+colorama==0.4.6
+ # via sacrebleu
+datasets==2.19.2
+ # via trl
+ # via webnavix
+deepspeed==0.15.1
+ # via webnavix
+dill==0.3.7
+ # via datasets
+ # via multiprocess
+docker-pycreds==0.4.0
+ # via wandb
+docstring-parser==0.16
+ # via tyro
+filelock==3.15.4
+ # via datasets
+ # via huggingface-hub
+ # via torch
+ # via transformers
+frozenlist==1.4.1
+ # via aiohttp
+ # via aiosignal
+fsspec==2023.10.0
+ # via datasets
+ # via huggingface-hub
+ # via torch
+gitdb==4.0.11
+ # via gitpython
+gitpython==3.1.43
+ # via wandb
+hjson==3.1.0
+ # via deepspeed
+huggingface-hub==0.24.6
+ # via accelerate
+ # via datasets
+ # via mergoo
+ # via peft
+ # via tokenizers
+ # via transformers
+ # via webnavix
+hydra-core==1.3.2
+ # via webnavix
+idna==3.7
+ # via requests
+ # via yarl
+jinja2==3.1.4
+ # via torch
+lefthook==0.1.2
+lxml==5.3.0
+ # via sacrebleu
+ # via webnavix
+markdown-it-py==3.0.0
+ # via rich
+markupsafe==2.1.5
+ # via jinja2
+mdurl==0.1.2
+ # via markdown-it-py
+mergoo==0.0.10
+ # via webnavix
+mpmath==1.3.0
+ # via sympy
+multidict==6.0.5
+ # via aiohttp
+ # via yarl
+multiprocess==0.70.15
+ # via datasets
+networkx==3.3
+ # via torch
+ninja==1.11.1.1
+ # via deepspeed
+numpy==1.26.4
+ # via accelerate
+ # via datasets
+ # via deepspeed
+ # via mergoo
+ # via pandas
+ # via peft
+ # via pyarrow
+ # via sacrebleu
+ # via scipy
+ # via transformers
+ # via trl
+ # via weblinx
+omegaconf==2.3.0
+ # via hydra-core
+ # via webnavix
+packaging==24.1
+ # via accelerate
+ # via datasets
+ # via deepspeed
+ # via huggingface-hub
+ # via hydra-core
+ # via peft
+ # via transformers
+pandas==2.2.2
+ # via datasets
+ # via weblinx
+peft==0.12.0
+ # via mergoo
+ # via webnavix
+platformdirs==4.2.2
+ # via wandb
+portalocker==2.10.1
+ # via sacrebleu
+protobuf==5.27.2
+ # via mergoo
+ # via wandb
+psutil==6.0.0
+ # via accelerate
+ # via deepspeed
+ # via peft
+ # via wandb
+py-cpuinfo==9.0.0
+ # via deepspeed
+pyarrow==17.0.0
+ # via datasets
+pyarrow-hotfix==0.6
+ # via datasets
+pydantic==2.9.2
+ # via deepspeed
+pydantic-core==2.23.4
+ # via pydantic
+pygments==2.18.0
+ # via rich
+python-dateutil==2.9.0.post0
+ # via pandas
+python-dotenv==1.0.1
+ # via webnavix
+pytz==2024.1
+ # via pandas
+pyyaml==6.0.1
+ # via accelerate
+ # via datasets
+ # via huggingface-hub
+ # via omegaconf
+ # via peft
+ # via transformers
+ # via wandb
+regex==2024.5.15
+ # via sacrebleu
+ # via transformers
+requests==2.32.3
+ # via datasets
+ # via fsspec
+ # via huggingface-hub
+ # via transformers
+ # via wandb
+rich==13.7.1
+ # via tyro
+ruff==0.6.4
+sacrebleu==2.4.3
+ # via weblinx
+safetensors==0.4.3
+ # via accelerate
+ # via mergoo
+ # via peft
+ # via transformers
+scipy==1.14.0
+ # via bitsandbytes
+sentencepiece==0.2.0
+ # via mergoo
+sentry-sdk==2.13.0
+ # via wandb
+setproctitle==1.3.3
+ # via wandb
+setuptools==71.1.0
+ # via torch
+ # via wandb
+shtab==1.7.1
+ # via tyro
+six==1.16.0
+ # via docker-pycreds
+ # via python-dateutil
+smmap==5.0.1
+ # via gitdb
+sympy==1.13.1
+ # via torch
+tabulate==0.9.0
+ # via sacrebleu
+tokenizers==0.19.1
+ # via transformers
+torch==2.4.1
+ # via accelerate
+ # via deepspeed
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+tqdm==4.66.2
+ # via datasets
+ # via deepspeed
+ # via huggingface-hub
+ # via mergoo
+ # via peft
+ # via transformers
+ # via weblinx
+ # via webnavix
+transformers==4.42.4
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+trl==0.10.1
+ # via webnavix
+typing-extensions==4.12.2
+ # via huggingface-hub
+ # via mergoo
+ # via pydantic
+ # via pydantic-core
+ # via torch
+ # via tyro
+tyro==0.8.5
+ # via trl
+tzdata==2024.1
+ # via pandas
+urllib3==2.2.2
+ # via requests
+ # via sentry-sdk
+wandb==0.17.9
+ # via webnavix
+weblinx==0.3.0
+ # via webnavix
+xxhash==3.4.1
+ # via datasets
+yarl==1.9.4
+ # via aiohttp
diff --git a/requirements.lock b/requirements.lock
new file mode 100644
index 0000000..4f7883f
--- /dev/null
+++ b/requirements.lock
@@ -0,0 +1,262 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: []
+# all-features: false
+# with-sources: false
+# generate-hashes: false
+# universal: false
+
+-e file:.
+accelerate==0.27.2
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+aiohttp==3.9.5
+ # via datasets
+ # via fsspec
+aiosignal==1.3.1
+ # via aiohttp
+annotated-types==0.7.0
+ # via pydantic
+antlr4-python3-runtime==4.9.3
+ # via hydra-core
+ # via omegaconf
+attrs==23.2.0
+ # via aiohttp
+bitsandbytes==0.42.0
+ # via webnavix
+certifi==2024.7.4
+ # via requests
+ # via sentry-sdk
+charset-normalizer==3.3.2
+ # via requests
+click==8.1.7
+ # via wandb
+colorama==0.4.6
+ # via sacrebleu
+datasets==2.19.2
+ # via trl
+ # via webnavix
+deepspeed==0.15.1
+ # via webnavix
+dill==0.3.7
+ # via datasets
+ # via multiprocess
+docker-pycreds==0.4.0
+ # via wandb
+docstring-parser==0.16
+ # via tyro
+filelock==3.15.4
+ # via datasets
+ # via huggingface-hub
+ # via torch
+ # via transformers
+frozenlist==1.4.1
+ # via aiohttp
+ # via aiosignal
+fsspec==2023.10.0
+ # via datasets
+ # via huggingface-hub
+ # via torch
+gitdb==4.0.11
+ # via gitpython
+gitpython==3.1.43
+ # via wandb
+hjson==3.1.0
+ # via deepspeed
+huggingface-hub==0.24.6
+ # via accelerate
+ # via datasets
+ # via mergoo
+ # via peft
+ # via tokenizers
+ # via transformers
+ # via webnavix
+hydra-core==1.3.2
+ # via webnavix
+idna==3.7
+ # via requests
+ # via yarl
+jinja2==3.1.4
+ # via torch
+lxml==5.3.0
+ # via sacrebleu
+ # via webnavix
+markdown-it-py==3.0.0
+ # via rich
+markupsafe==2.1.5
+ # via jinja2
+mdurl==0.1.2
+ # via markdown-it-py
+mergoo==0.0.10
+ # via webnavix
+mpmath==1.3.0
+ # via sympy
+multidict==6.0.5
+ # via aiohttp
+ # via yarl
+multiprocess==0.70.15
+ # via datasets
+networkx==3.3
+ # via torch
+ninja==1.11.1.1
+ # via deepspeed
+numpy==1.26.4
+ # via accelerate
+ # via datasets
+ # via deepspeed
+ # via mergoo
+ # via pandas
+ # via peft
+ # via pyarrow
+ # via sacrebleu
+ # via scipy
+ # via transformers
+ # via trl
+ # via weblinx
+omegaconf==2.3.0
+ # via hydra-core
+ # via webnavix
+packaging==24.1
+ # via accelerate
+ # via datasets
+ # via deepspeed
+ # via huggingface-hub
+ # via hydra-core
+ # via peft
+ # via transformers
+pandas==2.2.2
+ # via datasets
+ # via weblinx
+peft==0.12.0
+ # via mergoo
+ # via webnavix
+platformdirs==4.2.2
+ # via wandb
+portalocker==2.10.1
+ # via sacrebleu
+protobuf==5.27.2
+ # via mergoo
+ # via wandb
+psutil==6.0.0
+ # via accelerate
+ # via deepspeed
+ # via peft
+ # via wandb
+py-cpuinfo==9.0.0
+ # via deepspeed
+pyarrow==17.0.0
+ # via datasets
+pyarrow-hotfix==0.6
+ # via datasets
+pydantic==2.9.2
+ # via deepspeed
+pydantic-core==2.23.4
+ # via pydantic
+pygments==2.18.0
+ # via rich
+python-dateutil==2.9.0.post0
+ # via pandas
+python-dotenv==1.0.1
+ # via webnavix
+pytz==2024.1
+ # via pandas
+pyyaml==6.0.1
+ # via accelerate
+ # via datasets
+ # via huggingface-hub
+ # via omegaconf
+ # via peft
+ # via transformers
+ # via wandb
+regex==2024.5.15
+ # via sacrebleu
+ # via transformers
+requests==2.32.3
+ # via datasets
+ # via fsspec
+ # via huggingface-hub
+ # via transformers
+ # via wandb
+rich==13.7.1
+ # via tyro
+sacrebleu==2.4.3
+ # via weblinx
+safetensors==0.4.3
+ # via accelerate
+ # via mergoo
+ # via peft
+ # via transformers
+scipy==1.14.0
+ # via bitsandbytes
+sentencepiece==0.2.0
+ # via mergoo
+sentry-sdk==2.13.0
+ # via wandb
+setproctitle==1.3.3
+ # via wandb
+setuptools==71.1.0
+ # via torch
+ # via wandb
+shtab==1.7.1
+ # via tyro
+six==1.16.0
+ # via docker-pycreds
+ # via python-dateutil
+smmap==5.0.1
+ # via gitdb
+sympy==1.13.1
+ # via torch
+tabulate==0.9.0
+ # via sacrebleu
+tokenizers==0.19.1
+ # via transformers
+torch==2.4.1
+ # via accelerate
+ # via deepspeed
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+tqdm==4.66.2
+ # via datasets
+ # via deepspeed
+ # via huggingface-hub
+ # via mergoo
+ # via peft
+ # via transformers
+ # via weblinx
+ # via webnavix
+transformers==4.42.4
+ # via mergoo
+ # via peft
+ # via trl
+ # via webnavix
+trl==0.10.1
+ # via webnavix
+typing-extensions==4.12.2
+ # via huggingface-hub
+ # via mergoo
+ # via pydantic
+ # via pydantic-core
+ # via torch
+ # via tyro
+tyro==0.8.5
+ # via trl
+tzdata==2024.1
+ # via pandas
+urllib3==2.2.2
+ # via requests
+ # via sentry-sdk
+wandb==0.17.9
+ # via webnavix
+weblinx==0.3.0
+ # via webnavix
+xxhash==3.4.1
+ # via datasets
+yarl==1.9.4
+ # via aiohttp
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..10c34d9
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,48 @@
+target-version = "py312"
+indent-width = 4
+line-length = 120
+exclude = [
+ "__pycache__",
+ ".cache",
+ ".mypy_cache",
+ ".ruff_cache",
+ ".venv",
+ "*.egg-info",
+ "build",
+ "checkpoints",
+ "dist",
+ "out",
+ "results",
+ "wandb",
+ "wl_data",
+ "*log*",
+]
+
+[format]
+indent-style = "space"
+line-ending = "auto"
+quote-style = "double"
+skip-magic-trailing-comma = false
+
+
+[lint]
+select = ["ALL"]
+fixable = ["ALL"]
+ignore = [
+ "ANN001",
+ "EM101",
+ "ERA001",
+ "FBT001",
+ "FBT002",
+ "RET504",
+ "TRY002",
+ "TRY003",
+]
+
+[lint.extend-per-file-ignores]
+"__init__.py" = ["D1", "F403"]
+"build.py" = ["D1"]
+"dataset.py" = ["D1"]
+"eval.py" = ["D1"]
+"merge.py" = ["D1"]
+"train.py" = ["D1"]
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000..10e47c8
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,27 @@
+import time
+
+from huggingface_hub import snapshot_download
+
+attempt = 1
+while True:
+ try:
+ print(f"Attempt {attempt} for dataset 'McGill-NLP/WebLINX-full'.") # noqa: T201
+
+ snapshot_download(
+ repo_id="McGill-NLP/WebLINX-full",
+ repo_type="dataset",
+ local_dir="./wl_data",
+ ignore_patterns=["**/video.mp4"],
+ resume_download=True,
+ max_workers=16,
+ )
+
+ print("Download successful!") # noqa: T201
+ break
+
+ except Exception as e: # noqa: BLE001
+ print(f"Error occurred: {e}") # noqa: T201
+
+ print("Retrying in 5 seconds...") # noqa: T201
+ time.sleep(5)
+ attempt += 1
diff --git a/src/llama/__init__.py b/src/llama/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/llama/build.py b/src/llama/build.py
new file mode 100644
index 0000000..d1bcd2e
--- /dev/null
+++ b/src/llama/build.py
@@ -0,0 +1,95 @@
+import json
+import logging
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import huggingface_hub
+import hydra
+import wandb
+from dotenv import load_dotenv
+from omegaconf import OmegaConf
+from transformers import AutoTokenizer
+from weblinx import Demonstration
+from weblinx.processing import load_candidate_elements
+from weblinx.processing.prompt import (
+ build_input_records_from_selected_turns,
+ select_turns_and_candidates_for_prompts,
+)
+from weblinx.utils import (
+ load_demo_names_in_split,
+ set_seed,
+)
+
+from .processing import (
+ build_formatter_for_multichoice,
+ build_prompt_records_for_llama_truncated,
+ insert_formatted_chat_into_records,
+)
+
+load_dotenv()
+
+
+@hydra.main(config_path="conf", config_name="config", version_base=None)
+def main(cfg) -> None:
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info(OmegaConf.to_yaml(cfg))
+
+ huggingface_hub.login(token=cfg.huggingface.token)
+ wandb.login(key=cfg.wandb.key)
+
+ set_seed(cfg.seed)
+
+ model_save_dir = Path(cfg.model.save_dir).expanduser()
+ model_save_dir.mkdir(exist_ok=True, parents=True)
+ input_records_save_dir = Path(model_save_dir.joinpath(f"{cfg.build.split}")).expanduser()
+ input_records_save_dir.mkdir(exist_ok=True, parents=True)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ cfg.model.base_name,
+ padding_side="right",
+ trust_remote_code=True,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ split_path = Path(cfg.data.split_path).expanduser()
+
+ demo_names: list[str] = load_demo_names_in_split(split_path, split=cfg.build.split)
+ demos = [Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names]
+ candidates = load_candidate_elements(path=cfg.candidates.build_path)
+
+ format_intent = build_formatter_for_multichoice()
+ build_prompt_records_fn = partial(
+ build_prompt_records_for_llama_truncated,
+ format_intent=format_intent,
+ tokenizer=tokenizer, # type: ignore # noqa: PGH003
+ )
+
+ selected_turns: list[dict[str, Any]] = select_turns_and_candidates_for_prompts(
+ demos=demos,
+ candidates=candidates,
+ num_candidates=cfg.candidates.k,
+ )
+
+ input_records = build_input_records_from_selected_turns(
+ selected_turns=selected_turns,
+ format_intent=format_intent,
+ build_prompt_records_fn=build_prompt_records_fn,
+ format_prompt_records_fn=None,
+ )
+
+ input_records = insert_formatted_chat_into_records(
+ input_records,
+ demos,
+ tokenizer,
+ include_output_target=cfg.build.include_output_target,
+ )
+
+ with Path.open(input_records_save_dir.joinpath("input_records.json"), "w") as f:
+ json.dump(input_records, f, indent=2)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/llama/conf/accelerate/deepspeed.yaml b/src/llama/conf/accelerate/deepspeed.yaml
new file mode 100644
index 0000000..c4882ae
--- /dev/null
+++ b/src/llama/conf/accelerate/deepspeed.yaml
@@ -0,0 +1,13 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: DEEPSPEED
+deepspeed_config:
+ deepspeed_config_file: ../deepspeed/zero2.json
+ # deepspeed_moe_layer_cls_names: MoeLayer
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+use_cpu: false
diff --git a/src/llama/conf/accelerate/fsdp.yaml b/src/llama/conf/accelerate/fsdp.yaml
new file mode 100644
index 0000000..743503e
--- /dev/null
+++ b/src/llama/conf/accelerate/fsdp.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: FSDP
+fsdp_config:
+ fsdp_sharding_strategy: FULL_SHARD
+ fsdp_offload_params: false
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+ fsdp_backward_prefetch_policy: BACKWARD_PRE
+ fsdp_forward_prefetch: false
+ fsdp_state_dict_type: SHARDED_STATE_DICT
+ fsdp_use_orig_params: false
+ fsdp_cpu_ram_efficient_loading: true
+ fsdp_sync_module_states: true
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+use_cpu: false
diff --git a/src/llama/conf/deepspeed/zero2.json b/src/llama/conf/deepspeed/zero2.json
new file mode 100644
index 0000000..5f0ac20
--- /dev/null
+++ b/src/llama/conf/deepspeed/zero2.json
@@ -0,0 +1,14 @@
+{
+ "train_batch_size": "auto",
+ "gradient_accumulation_steps": "auto",
+ "optimizer": "auto",
+ "scheduler": "auto",
+ "bf16": "auto",
+ "zero_optimization": {
+ "stage": 2,
+ "contiguous_gradients": true,
+ "offload_optimizer": {
+ "device": "cpu"
+ }
+ }
+}
diff --git a/src/llama/conf/variant/base.yaml b/src/llama/conf/variant/base.yaml
new file mode 100644
index 0000000..952edc1
--- /dev/null
+++ b/src/llama/conf/variant/base.yaml
@@ -0,0 +1,117 @@
+project_name: webnavix
+project_dir: ${oc.env:PROJECT_DIR}
+seed: 123
+
+data:
+ num_proc: 8
+ split_path: ${project_dir}/wl_data/splits.json
+ base_dir: ${project_dir}/wl_data/demonstrations/
+
+candidates:
+ k: 10
+ build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl
+ train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl
+ eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl
+
+model:
+ name: nitic-nlp-team/webnavix-llama-base
+ base_name: princeton-nlp/Sheared-LLaMA-2.7B
+ save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
+ max_inp_len: null
+ max_out_len: 256
+ use_rope: True
+ use_flash_attention_2: True
+ moe: False
+ freeze:
+ use: False
+ trainable_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+build:
+ split: train
+ include_output_target: True
+
+train:
+ split: train
+ domain: False
+ num_epochs: 3
+ learning_rate: 5e-5
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: True
+ max_grad_norm: 1.0
+ optim: adamw_torch
+ weight_decay: 0.0
+ scheduler: linear
+ warmup_steps: 0
+ warmup_ratio: 0.0
+ accelerate:
+ use: False
+ qlora:
+ use: False
+ r: 256
+ alpha: 256
+ dropout: 0.05
+ bias: none
+ target_modules:
+ - embed_tokens
+ - q_proj
+ - k_proj
+ - v_proj
+ - o_proj
+ - gate_proj
+ - up_proj
+ - down_proj
+ - lm_head
+
+merge:
+ num_experts_per_tok: 2
+ experts:
+ - expert_name: ai-tools-expert
+ model_id: nitic-nlp-team/webnavix-llama-ai-tools
+ - expert_name: booking-expert
+ model_id: nitic-nlp-team/webnavix-llama-booking
+ - expert_name: composing-expert
+ model_id: nitic-nlp-team/webnavix-llama-composing
+ - expert_name: information-lookup-expert
+ model_id: nitic-nlp-team/webnavix-llama-information-lookup
+ - expert_name: shopping-expert
+ model_id: nitic-nlp-team/webnavix-llama-shopping
+ - expert_name: social-interaction-expert
+ model_id: nitic-nlp-team/webnavix-llama-social-interaction
+ - expert_name: summarizing-expert
+ model_id: nitic-nlp-team/webnavix-llama-summarizing
+ - expert_name: task-management-expert
+ model_id: nitic-nlp-team/webnavix-llama-task-management
+ - expert_name: shared-expert
+ model_id: nitic-nlp-team/webnavix-llama-shared
+ router_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+eval:
+ split: valid
+ domain: False
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name}
+ load_from_save_dir: True
+
+huggingface:
+ token: ${oc.env:HF_TOKEN}
+
+wandb:
+ project: ${oc.env:WANDB_PROJECT}
+ key: ${oc.env:WANDB_API_KEY}
+
+hydra:
+ run:
+ dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
+ sweep:
+ dir: ${hydra.run.dir}
+ job:
+ chdir: False
+ verbose: INFO
diff --git a/src/llama/conf/variant/expert.yaml b/src/llama/conf/variant/expert.yaml
new file mode 100644
index 0000000..5b68de5
--- /dev/null
+++ b/src/llama/conf/variant/expert.yaml
@@ -0,0 +1,117 @@
+project_name: webnavix
+project_dir: ${oc.env:PROJECT_DIR}
+seed: 123
+
+data:
+ num_proc: 8
+ split_path: ${project_dir}/wl_data/splits.json
+ base_dir: ${project_dir}/wl_data/demonstrations/
+
+candidates:
+ k: 10
+ build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl
+ train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl
+ eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl
+
+model:
+ name: nitic-nlp-team/webnavix-ai-tools
+ base_name: nitic-nlp-team/webnavix-llama-base
+ save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
+ max_inp_len: null
+ max_out_len: 256
+ use_rope: True
+ use_flash_attention_2: True
+ moe: False
+ freeze:
+ use: True
+ trainable_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+build:
+ split: train
+ include_output_target: True
+
+train:
+ split: train
+ domain: AI_Tools
+ num_epochs: 3
+ learning_rate: 5e-5
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: True
+ max_grad_norm: 1.0
+ optim: adamw_torch
+ weight_decay: 0.0
+ scheduler: linear
+ warmup_steps: 0
+ warmup_ratio: 0.0
+ accelerate:
+ use: False
+ qlora:
+ use: False
+ r: 256
+ alpha: 256
+ dropout: 0.05
+ bias: none
+ target_modules:
+ - embed_tokens
+ - q_proj
+ - k_proj
+ - v_proj
+ - o_proj
+ - gate_proj
+ - up_proj
+ - down_proj
+ - lm_head
+
+merge:
+ num_experts_per_tok: 2
+ experts:
+ - expert_name: ai-tools-expert
+ model_id: nitic-nlp-team/webnavix-llama-ai-tools
+ - expert_name: booking-expert
+ model_id: nitic-nlp-team/webnavix-llama-booking
+ - expert_name: composing-expert
+ model_id: nitic-nlp-team/webnavix-llama-composing
+ - expert_name: information-lookup-expert
+ model_id: nitic-nlp-team/webnavix-llama-information-lookup
+ - expert_name: shopping-expert
+ model_id: nitic-nlp-team/webnavix-llama-shopping
+ - expert_name: social-interaction-expert
+ model_id: nitic-nlp-team/webnavix-llama-social-interaction
+ - expert_name: summarizing-expert
+ model_id: nitic-nlp-team/webnavix-llama-summarizing
+ - expert_name: task-management-expert
+ model_id: nitic-nlp-team/webnavix-llama-task-management
+ - expert_name: shared-expert
+ model_id: nitic-nlp-team/webnavix-llama-shared
+ router_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+eval:
+ split: valid
+ domain: AI_Tools
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name}
+ load_from_save_dir: True
+
+huggingface:
+ token: ${oc.env:HF_TOKEN}
+
+wandb:
+ project: ${oc.env:WANDB_PROJECT}
+ key: ${oc.env:WANDB_API_KEY}
+
+hydra:
+ run:
+ dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
+ sweep:
+ dir: ${hydra.run.dir}
+ job:
+ chdir: False
+ verbose: INFO
diff --git a/src/llama/conf/variant/merge.yaml b/src/llama/conf/variant/merge.yaml
new file mode 100644
index 0000000..23f0927
--- /dev/null
+++ b/src/llama/conf/variant/merge.yaml
@@ -0,0 +1,117 @@
+project_name: webnavix
+project_dir: ${oc.env:PROJECT_DIR}
+seed: 123
+
+data:
+ num_proc: 8
+ split_path: ${project_dir}/wl_data/splits.json
+ base_dir: ${project_dir}/wl_data/demonstrations/
+
+candidates:
+ k: 10
+ build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl
+ train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl
+ eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl
+
+model:
+ name: nitic-nlp-team/webnavix-llama-merged
+ base_name: nitic-nlp-team/webnavix-llama-base
+ save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
+ max_inp_len: null
+ max_out_len: 256
+ use_rope: True
+ use_flash_attention_2: True
+ moe: True
+ freeze:
+ use: False
+ trainable_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+build:
+ split: train
+ include_output_target: True
+
+train:
+ split: train
+ domain: False
+ num_epochs: 3
+ learning_rate: 5e-5
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: True
+ max_grad_norm: 1.0
+ optim: adamw_torch
+ weight_decay: 0.0
+ scheduler: linear
+ warmup_steps: 0
+ warmup_ratio: 0.0
+ accelerate:
+ use: False
+ qlora:
+ use: False
+ r: 256
+ alpha: 256
+ dropout: 0.05
+ bias: none
+ target_modules:
+ - embed_tokens
+ - q_proj
+ - k_proj
+ - v_proj
+ - o_proj
+ - gate_proj
+ - up_proj
+ - down_proj
+ - lm_head
+
+merge:
+ num_experts_per_tok: 2
+ experts:
+ - expert_name: ai-tools-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-ai-tools/checkpoint-500/
+ - expert_name: booking-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-booking/checkpoint-1050/
+ - expert_name: composing-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-composing/checkpoint-500/
+ - expert_name: information-lookup-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-information-lookup/checkpoint-500/
+ - expert_name: shopping-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-shopping/checkpoint-400/
+ - expert_name: social-interaction-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-social-interaction/checkpoint-150/
+ - expert_name: summarizing-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-summarizing/checkpoint-450/
+ - expert_name: task-management-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-task-management/checkpoint-450/
+ - expert_name: shared-expert
+ model_id: /content/drive/MyDrive/Projects/nitic-nlp-team/webnavix/checkpoints/webnavix/nitic-nlp-team/webnavix-llama-shared/checkpoint-1800/
+ router_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+eval:
+ split: valid
+ domain: False
+ batch_size_per_device: 16
+ gradient_accumulation_steps: 1
+ result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name}
+ load_from_save_dir: True
+
+huggingface:
+ token: ${oc.env:HF_TOKEN}
+
+wandb:
+ project: ${oc.env:WANDB_PROJECT}
+ key: ${oc.env:WANDB_API_KEY}
+
+hydra:
+ run:
+ dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
+ sweep:
+ dir: ${hydra.run.dir}
+ job:
+ chdir: False
+ verbose: INFO
diff --git a/src/llama/conf/variant/moe.yaml b/src/llama/conf/variant/moe.yaml
new file mode 100644
index 0000000..aab8b00
--- /dev/null
+++ b/src/llama/conf/variant/moe.yaml
@@ -0,0 +1,107 @@
+project_name: webnavix
+project_dir: ${oc.env:PROJECT_DIR}
+seed: 123
+
+data:
+ num_proc: 8
+ split_path: ${project_dir}/wl_data/splits.json
+ base_dir: ${project_dir}/wl_data/demonstrations/
+
+candidates:
+ k: 10
+ build_path: ${project_dir}/wl_data/candidates/${build.split}.jsonl
+ train_path: ${project_dir}/wl_data/candidates/${train.split}.jsonl
+ eval_path: ${project_dir}/wl_data/candidates/${eval.split}.jsonl
+
+model:
+ name: nitic-nlp-team/webnavix-llama
+ base_name: nitic-nlp-team/webnavix-llama-merged
+ save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
+ max_inp_len: null
+ max_out_len: 256
+ use_rope: True
+ use_flash_attention_2: True
+ moe: True
+ freeze:
+ use: False
+ trainable_layers:
+ - gate
+
+build:
+ split: train
+ include_output_target: True
+
+train:
+ split: train
+ domain: False
+ num_epochs: 3
+ learning_rate: 5e-5
+ batch_size_per_device: 4
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: True
+ max_grad_norm: 1.0
+ optim: adamw_torch
+ weight_decay: 0.0
+ scheduler: linear
+ warmup_steps: 0
+ warmup_ratio: 0.0
+ accelerate:
+ use: True
+ qlora:
+ use: False
+ r: 256
+ alpha: 256
+ dropout: 0.05
+ bias: none
+ target_modules:
+ - gate
+
+merge:
+ num_experts_per_tok: 2
+ experts:
+ - expert_name: ai-tools-expert
+ model_id: nitic-nlp-team/webnavix-llama-ai-tools
+ - expert_name: booking-expert
+ model_id: nitic-nlp-team/webnavix-llama-booking
+ - expert_name: composing-expert
+ model_id: nitic-nlp-team/webnavix-llama-composing
+ - expert_name: information-lookup-expert
+ model_id: nitic-nlp-team/webnavix-llama-information-lookup
+ - expert_name: shopping-expert
+ model_id: nitic-nlp-team/webnavix-llama-shopping
+ - expert_name: social-interaction-expert
+ model_id: nitic-nlp-team/webnavix-llama-social-interaction
+ - expert_name: summarizing-expert
+ model_id: nitic-nlp-team/webnavix-llama-summarizing
+ - expert_name: task-management-expert
+ model_id: nitic-nlp-team/webnavix-llama-task-management
+ - expert_name: shared-expert
+ model_id: nitic-nlp-team/webnavix-llama-shared
+ router_layers:
+ - gate_proj
+ - up_proj
+ - down_proj
+
+eval:
+ split: valid
+ domain: False
+ batch_size_per_device: 4
+ gradient_accumulation_steps: 1
+ result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name}
+ load_from_save_dir: True
+
+huggingface:
+ token: ${oc.env:HF_TOKEN}
+
+wandb:
+ project: ${oc.env:WANDB_PROJECT}
+ key: ${oc.env:WANDB_API_KEY}
+
+hydra:
+ run:
+ dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
+ sweep:
+ dir: ${hydra.run.dir}
+ job:
+ chdir: False
+ verbose: INFO
diff --git a/src/llama/eval.py b/src/llama/eval.py
new file mode 100644
index 0000000..565af51
--- /dev/null
+++ b/src/llama/eval.py
@@ -0,0 +1,142 @@
+import json
+import logging
+from pathlib import Path
+from typing import Any
+
+import huggingface_hub
+import hydra
+import torch
+import wandb
+from accelerate import Accelerator
+from dotenv import load_dotenv
+from mergoo.models.modeling_llama import LlamaForCausalLM
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ BitsAndBytesConfig, # type: ignore # noqa: PGH003
+ PreTrainedTokenizer,
+ PreTrainedTokenizerFast,
+ pipeline, # type: ignore # noqa: PGH003
+)
+from transformers.pipelines.pt_utils import KeyDataset
+from weblinx.utils import set_seed
+
+load_dotenv()
+
+
+@hydra.main(version_base=None, config_path="conf", config_name="config")
+def main(cfg) -> None:
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info(OmegaConf.to_yaml(cfg))
+
+ huggingface_hub.login(token=cfg.huggingface.token)
+ wandb.login(key=cfg.wandb.key)
+
+ set_seed(cfg.seed)
+
+ model_save_dir = Path(cfg.model.save_dir).expanduser()
+ model_save_dir.mkdir(exist_ok=True, parents=True)
+ result_dir = Path(cfg.eval.result_dir).expanduser()
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ bnb_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ bnb_4bit_use_double_quant=True,
+ )
+
+ load_model_name = str(model_save_dir) if cfg.eval.get("load_from_save_dir", False) is True else cfg.model.base_name
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ load_model_name,
+ padding_side="left",
+ trust_remote_code=True,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ accelerator = Accelerator() if cfg.train.accelerate.use else None
+ model = (
+ LlamaForCausalLM.from_pretrained(
+ cfg.model.base_name,
+ device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+ torch_dtype=torch.bfloat16,
+ rope_scaling={"type": "dynamic", "factor": 2.0} if cfg.model.use_rope else None,
+ attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+ quantization_config=bnb_config if cfg.train.qlora.use else None,
+ )
+ if cfg.model.moe
+ else AutoModelForCausalLM.from_pretrained(
+ cfg.model.base_name,
+ device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+ torch_dtype=torch.bfloat16,
+ rope_scaling={"type": "dynamic", "factor": 2.0} if cfg.model.use_rope else None,
+ attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+ quantization_config=bnb_config if cfg.train.qlora.use else None,
+ )
+ )
+
+ with Path.open(
+ model_save_dir.joinpath(
+ f"{cfg.eval.split}/{cfg.eval.domain if cfg.eval.domain else ''}/input_records.json",
+ ),
+ "r",
+ ) as f:
+ input_records = json.load(f)
+
+ evaluate(cfg, model, tokenizer, input_records, result_dir)
+
+
+def evaluate(
+ cfg,
+ model,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+ input_records: list[dict[str, Any]],
+ result_dir: Path,
+) -> None:
+ key_dataset = KeyDataset(input_records, key="text") # type: ignore # noqa: PGH003
+ pipe = pipeline(
+ "text-generation",
+ model=model,
+ tokenizer=tokenizer,
+ torch_dtype=torch.bfloat16,
+ )
+
+ results = []
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16): # type: ignore # noqa: PGH003
+ pbar = tqdm(
+ pipe(
+ key_dataset,
+ batch_size=cfg.eval.batch_size_per_device,
+ pad_token_id=tokenizer.unk_token_id,
+ max_new_tokens=cfg.model.max_out_len,
+ return_full_text=False,
+ ),
+ desc="Generating outputs",
+ total=len(key_dataset),
+ )
+ for i, out in enumerate(pbar):
+ input_record = input_records[i]
+ generated_text = out[0]["generated_text"]
+ result = {
+ "demo_name": input_record["demo_name"],
+ "turn_index": input_record["turn_index"],
+ "prompt": input_record["prompt"],
+ "text": input_record["text"],
+ "output_predicted": generated_text,
+ "output_target": input_record["output_target"],
+ "output_target_dict": input_record["output_target_dict"],
+ }
+
+ results.append(result)
+
+ with Path.open(result_dir.joinpath("results.json"), "w") as f:
+ json.dump(results, f, indent=2)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/llama/merge.py b/src/llama/merge.py
new file mode 100644
index 0000000..000d3b9
--- /dev/null
+++ b/src/llama/merge.py
@@ -0,0 +1,38 @@
+import logging
+from pathlib import Path
+
+import huggingface_hub
+import hydra
+import torch
+from mergoo.compose_experts import ComposeExperts
+from omegaconf import OmegaConf
+from weblinx.utils import set_seed
+
+
+@hydra.main(config_path="conf", config_name="config", version_base=None)
+def main(cfg) -> None:
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info(OmegaConf.to_yaml(cfg))
+
+ huggingface_hub.login(token=cfg.huggingface.token)
+
+ set_seed(cfg.seed)
+
+ model_save_dir = Path(cfg.model.save_dir).expanduser()
+ model_save_dir.mkdir(exist_ok=True, parents=True)
+
+ merge_config = {
+ "model_type": "llama",
+ "num_experts_per_tok": OmegaConf.to_container(cfg.merge.num_experts_per_tok, resolve=True), # type: ignore # noqa: PGH003
+ "experts": OmegaConf.to_container(cfg.merge.experts, resolve=True), # type: ignore # noqa: PGH003
+ "router_layers": OmegaConf.to_container(cfg.merge.router_layers, resolve=True), # type: ignore # noqa: PGH003
+ }
+ merger = ComposeExperts(merge_config, torch_dtype=torch.bfloat16)
+ merger.compose()
+ merger.save_checkpoint(model_save_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/llama/processing.py b/src/llama/processing.py
new file mode 100644
index 0000000..5f1986f
--- /dev/null
+++ b/src/llama/processing.py
@@ -0,0 +1,397 @@
+"""The processing module contains functions to format the dataset and build input records."""
+
+from collections.abc import Callable
+from copy import deepcopy
+from functools import partial
+
+import lxml.html
+import weblinx.utils.format as wlf
+from transformers import (
+ PreTrainedTokenizer,
+ PreTrainedTokenizerFast,
+)
+from weblinx import (
+ Demonstration,
+ Replay,
+ Turn,
+)
+from weblinx.processing.dom import clean_and_prune_tree
+from weblinx.processing.prompt import (
+ find_turns_with_instructor_chat,
+ format_utterances,
+ get_speaker,
+ multi_attempt_format_prev_turns_truncated,
+)
+from weblinx.processing.truncation import (
+ multi_attempt_truncate_cands_turn,
+ multi_attempt_truncate_dom_tree,
+ reduce_list_of_lengths,
+ truncate_text_at_center,
+)
+from weblinx.utils.recs import get_list_from_records_by_key
+
+
+def build_formatter_for_multichoice() -> Callable: # noqa: D103
+ format_click = partial(wlf.format_click, formatters=(wlf.format_uid,))
+ format_text_input = partial(
+ wlf.format_text_input,
+ formatters=(
+ partial(wlf.format_arg_item, name="text", max_length=200),
+ wlf.format_uid,
+ ),
+ )
+ format_change = partial(
+ wlf.format_change,
+ formatters=(
+ partial(wlf.format_arg_item, name="value", max_length=200),
+ wlf.format_uid,
+ ),
+ )
+ format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,))
+ format_load = partial(
+ wlf.format_load,
+ include_transition=False,
+ include_timestamp=False,
+ max_length=200,
+ )
+ format_scroll = partial(wlf.format_scroll, include_timestamp=False)
+
+ format_say = partial(wlf.format_say, include_timestamp=False)
+
+ format_intent_auto = partial(
+ wlf.format_intent_automatically,
+ format_change=format_change,
+ format_click=format_click,
+ format_load=format_load,
+ format_say=format_say,
+ format_scroll=format_scroll,
+ format_submit=format_submit,
+ format_text_input=format_text_input,
+ )
+
+ return format_intent_auto
+
+
+def __get_system_prompt_template_for_llama_mc_concise() -> str:
+ sys_prompt_template = (
+ "You are an AI assistant with a deep understanding of HTML "
+ "and you must predict actions based on a user request, which will be executed. "
+ "Use one of the following, replacing [] with an appropriate value: "
+ "change(value=[str], uid=[str]) ; "
+ "click(uid=[str]) ; "
+ "load(url=[str]) ; "
+ 'say(speaker="navigator", utterance=[str]) ; '
+ "scroll(x=[int], y=[int]) ; "
+ "submit(uid=[str]) ;"
+ "text_input(text=[str], uid=[str]) ;\n"
+ "The user's first and last {num_utterances} utterances are: "
+ "{utterance_context} ;\n"
+ "Viewport size: {height}h x {width}w ;\n"
+ "Only the last {num_prev_turns} turns are provided."
+ )
+
+ return sys_prompt_template
+
+
+def __get_candidate_prompt_template_for_llama() -> str:
+ return "Here are the top candidates for this turn:\n{candidate_str}"
+
+
+def __get_final_user_message() -> str:
+ return (
+ "Please select the best action using the correct format, "
+ "do not provide any other information or explanation."
+ )
+
+
+def __merge_prev_turns(prev_turns_text_list: list[str], final_user_message: str) -> list[dict[str, str]]:
+ prev_turns_merged: list[dict[str, str]] = []
+
+ for i, turn_text in enumerate(prev_turns_text_list):
+ role = get_speaker(
+ turn_text,
+ instructor_name="user",
+ navigator_name="assistant",
+ default_name="unknown",
+ )
+
+ if i > 0 and prev_turns_merged[-1]["role"] == role:
+ prev_turns_merged[-1]["content"] += " " + turn_text
+ else:
+ prev_turns_merged.append({"role": role, "content": turn_text})
+
+ if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user":
+ prev_turns_merged[-1]["content"] += " " + final_user_message
+ else:
+ prev_turns_merged.append({"role": "user", "content": final_user_message})
+
+ return prev_turns_merged
+
+
+def __format_utterances_truncated( # noqa: ANN202, PLR0913
+ turns: list["Turn"],
+ tokenizer: "PreTrainedTokenizer",
+ max_tokens: int,
+ format_utterances_fn,
+ num_utterances: int = 5,
+ type_filter="chat",
+ sep=" ",
+ convert_to_minutes=True,
+ template="[{timestamp}] {utterance}",
+ allow_iterative_reduction=False,
+):
+ utterances = format_utterances_fn(
+ turns,
+ num_utterances=num_utterances,
+ type_filter=type_filter,
+ sep=None,
+ convert_to_minutes=convert_to_minutes,
+ template=template,
+ )
+ if isinstance(utterances, str):
+ utterances = [utterances]
+
+ utterances_str = " ".join(utterances) if sep is None else str(sep).join(utterances)
+ utter_tokens = tokenizer.tokenize(utterances_str, add_special_tokens=False)
+ num_tokens_to_remove = len(utter_tokens) - max_tokens
+
+ records = []
+ for i, text in enumerate(utterances):
+ tokens = tokenizer.tokenize(text, add_special_tokens=False)
+ records.append(
+ {
+ "index": i,
+ "text": text,
+ "tokens": tokens,
+ "length": len(tokens),
+ },
+ )
+
+ # NOTE: We only count the token lengths of the values, not the entire formatted string.
+ # The full string may have additional tokens. (key, separator, etc.)
+ # Consequently, max_total_length is different from max_tokens.
+ records = sorted(records, key=lambda r: r["length"])
+ lengths_orig = get_list_from_records_by_key(records, "length") # type: ignore # noqa: PGH003
+ max_total_length = sum(lengths_orig) - num_tokens_to_remove
+ lengths_reduced = reduce_list_of_lengths(lengths_orig, max_length=max_total_length)
+
+ for i, rec in enumerate(records):
+ red_length = lengths_reduced[i]
+
+ # NOTE: If the length is the same, then we don't need to do anything.
+ # Otherwise, we need to truncate the text.
+ if red_length >= rec["length"]:
+ continue
+
+ trunc = truncate_text_at_center(
+ rec["text"],
+ tokenizer=tokenizer,
+ max_tokens=red_length,
+ allow_iterative_reduction=allow_iterative_reduction,
+ )
+
+ utterances[rec["index"]] = trunc["text"]
+
+ if sep is None:
+ return utterances
+
+ return sep.join(utterances)
+
+
+def __format_candidates(candidates, max_char_len=300, use_uid_as_rank=False): # noqa: ANN202
+ s = ""
+ for cand in candidates:
+ doc = cand["doc"].replace("\n", " ").rstrip()
+ rank = "uid = " + cand["uid"] if use_uid_as_rank else cand["rank"]
+
+ if max_char_len is not None and len(doc) > max_char_len:
+ doc = doc[: max_char_len - 3] + "..."
+
+ s += f"({rank}) {doc}\n"
+
+ return s
+
+
+def build_prompt_records_for_llama_truncated( # noqa: D103, PLR0913
+ replay: Replay,
+ turn: Turn,
+ format_intent,
+ tokenizer: PreTrainedTokenizer,
+ cands_turn=None,
+ num_utterances: int = 5,
+ num_prev_turns: int = 5,
+ system_prompt_template=None,
+ candidate_prompt_template=None,
+ final_user_message=None,
+ include_html=True,
+ format_candidates_fn=partial( # noqa: B008
+ __format_candidates,
+ max_char_len=None, # type: ignore # noqa: PGH003
+ use_uid_as_rank=True,
+ ),
+ merge_prev_turns_fn=__merge_prev_turns,
+ format_output_dict_fn: Callable = partial( # noqa: B008
+ wlf.format_output_dictionary,
+ function_key="intent",
+ ),
+ max_html_tokens: int = 700,
+ max_utterance_tokens: int = 40 * 5,
+ max_prev_turns_tokens: int = 50 * 5,
+ max_candidates_tokens: int = 65 * 10,
+ add_unused_len_to_cands: bool = True,
+ allow_iterative_reduction: bool = False,
+ use_tokenizer_template: bool = False,
+ template_tokenizer=None,
+ parser=None,
+) -> list[dict[str, str]]:
+ if system_prompt_template is None:
+ system_prompt_template = __get_system_prompt_template_for_llama_mc_concise()
+
+ if candidate_prompt_template is None:
+ candidate_prompt_template = __get_candidate_prompt_template_for_llama()
+
+ if final_user_message is None:
+ final_user_message = __get_final_user_message()
+
+ instructor_chat_turns = find_turns_with_instructor_chat(
+ replay,
+ turn,
+ num_prev_turns=num_prev_turns,
+ )
+ utterance_context = __format_utterances_truncated(
+ instructor_chat_turns,
+ tokenizer=tokenizer,
+ max_tokens=max_utterance_tokens,
+ num_utterances=num_utterances,
+ format_utterances_fn=format_utterances,
+ allow_iterative_reduction=allow_iterative_reduction,
+ )
+
+ prev_turns_text_list = multi_attempt_format_prev_turns_truncated(
+ replay=replay,
+ turn=turn,
+ format_intent=partial(format_intent, return_as=dict),
+ tokenizer=tokenizer,
+ num_prev_turns=num_prev_turns,
+ turn_sep=None, # type: ignore # noqa: PGH003
+ max_tokens=max_prev_turns_tokens,
+ max_attempts=5,
+ format_output_dict_fn=format_output_dict_fn,
+ warn_after_attempts=False,
+ allow_iterative_reduction=allow_iterative_reduction,
+ )
+
+ prev_turns_merged = merge_prev_turns_fn(
+ prev_turns_text_list=prev_turns_text_list,
+ final_user_message=final_user_message,
+ )
+
+ sys_prompt = system_prompt_template.format(
+ num_utterances=num_utterances - 1, # NOTE: 1 less since we add the first utterance.
+ utterance_context=utterance_context,
+ height=turn.viewport_height,
+ width=turn.viewport_width,
+ num_prev_turns=num_prev_turns,
+ )
+
+ if include_html and turn.html not in ["", None] and cands_turn is not None:
+ dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser)
+ dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn)
+ trunc = multi_attempt_truncate_dom_tree(
+ dom_tree=dom_tree_pruned,
+ tokenizer=tokenizer,
+ max_tokens=max_html_tokens,
+ warn_after_attempts=False,
+ allow_iterative_reduction=allow_iterative_reduction,
+ )
+ html = trunc["tree_repr"]
+ sys_prompt = html + "\n" + sys_prompt
+ else:
+ html = ""
+
+ if cands_turn is not None:
+ if add_unused_len_to_cands:
+ # NOTE: Add the unused length to the candidates.
+ num_html_tokens = len(tokenizer.tokenize(html))
+ num_utter_tokens = len(tokenizer.tokenize(utterance_context)) # type: ignore # noqa: PGH003
+ if use_tokenizer_template:
+ if template_tokenizer is None:
+ msg = "template_tokenizer must be provided when use_tokenizer_template is True."
+ raise ValueError(msg)
+ prev_turns_merged_copy = deepcopy(prev_turns_merged)
+ if prev_turns_merged[0]["role"] == "assistant":
+ prev_turns_merged_copy.insert(0, {"role": "user", "content": ""})
+ num_prev_turns_tokens = len(
+ template_tokenizer.apply_chat_template(
+ [{"role": "system", "content": ""}, *prev_turns_merged_copy],
+ tokenize=True,
+ ),
+ )
+ else:
+ num_prev_turns_tokens = len(
+ tokenizer.tokenize(" ".join(prev_turns_text_list)),
+ )
+ remain_html_tokens = max_html_tokens - num_html_tokens
+ remain_utter_tokens = max_utterance_tokens - num_utter_tokens
+ remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens
+ remain_tokens = remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens
+ # NOTE: Add the unused length to the max_candidates_tokens.
+ max_candidates_tokens += remain_tokens
+
+ cands_turn_trunc = multi_attempt_truncate_cands_turn(
+ cands_turn=cands_turn,
+ tokenizer=tokenizer,
+ max_tokens=max_candidates_tokens,
+ format_candidates_fn=format_candidates_fn,
+ warn_after_attempts=False,
+ allow_iterative_reduction=allow_iterative_reduction,
+ )
+ cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None) # type: ignore # noqa: PGH003
+ cand_prompt = candidate_prompt_template.format(candidate_str=cand_str)
+ sys_prompt += cand_prompt[:-1]
+
+ return [{"role": "system", "content": sys_prompt}, *prev_turns_merged]
+
+
+def __insert_empty_user_content_at_first(prompt: list) -> None:
+ if prompt[0]["role"] != "system":
+ msg = f"First prompt must be a system prompt. Got {prompt[0]['role']} instead."
+ raise ValueError(msg)
+
+ if prompt[1]["role"] != "user":
+ prompt.insert(1, {"role": "user", "content": ""})
+
+
+def insert_formatted_chat_into_records( # noqa: D103
+ records,
+ demos: list[Demonstration],
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+ *,
+ include_output_target: bool = True,
+) -> list:
+ processed_records = deepcopy(records)
+ for i, record in enumerate(records):
+ __insert_empty_user_content_at_first(record["prompt"])
+
+ if include_output_target:
+ target = [{"role": "assistant", "content": record["output_target"]}]
+ combined = record["prompt"] + target
+ else:
+ combined = record["prompt"]
+
+ # NOTE: The `apply_chat_template` method of the tokenizer is required.
+ text = str(
+ tokenizer.apply_chat_template(
+ combined,
+ tokenize=False,
+ add_generation_prompt=False,
+ ),
+ )
+
+ processed_records[i]["text"] = text
+
+ processed_records[i]["tasks"] = next(
+ filter(lambda demo: demo.form["shortcode"] == record["demo_name"], demos),
+ ).form["tasks"]
+
+ return processed_records
diff --git a/src/llama/train.py b/src/llama/train.py
new file mode 100644
index 0000000..064254e
--- /dev/null
+++ b/src/llama/train.py
@@ -0,0 +1,185 @@
+import json
+import logging
+from pathlib import Path
+from typing import Any
+
+import huggingface_hub
+import hydra
+import torch
+import wandb
+from accelerate import Accelerator
+from datasets import Dataset
+from dotenv import load_dotenv
+from mergoo.models.modeling_llama import LlamaForCausalLM
+from omegaconf import OmegaConf
+from peft import LoraConfig # type: ignore # noqa: PGH003
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ BitsAndBytesConfig, # type: ignore # noqa: PGH003
+ PreTrainedTokenizer,
+ PreTrainedTokenizerFast,
+ TrainingArguments,
+)
+from trl import SFTTrainer
+from weblinx.utils import set_seed
+
+load_dotenv()
+
+
+@hydra.main(config_path="conf", config_name="config", version_base=None)
+def main(cfg) -> None:
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info(OmegaConf.to_yaml(cfg))
+
+ huggingface_hub.login(token=cfg.huggingface.token)
+ wandb.login(key=cfg.wandb.key)
+
+ set_seed(cfg.seed)
+
+ model_save_dir = Path(cfg.model.save_dir).expanduser()
+ model_save_dir.mkdir(exist_ok=True, parents=True)
+
+ bnb_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ bnb_4bit_use_double_quant=True,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ cfg.model.base_name,
+ padding_side="right",
+ trust_remote_code=True,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ accelerator = Accelerator() if cfg.train.accelerate.use else None
+ model = (
+ LlamaForCausalLM.from_pretrained(
+ cfg.model.base_name,
+ device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+ torch_dtype=torch.bfloat16,
+ use_cache=False,
+ attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+ quantization_config=bnb_config if cfg.train.qlora.use else None,
+ )
+ if cfg.model.moe
+ else AutoModelForCausalLM.from_pretrained(
+ cfg.model.base_name,
+ device_map={"": accelerator.process_index} if accelerator is not None else "auto",
+ torch_dtype=torch.bfloat16,
+ use_cache=False,
+ attn_implementation="flash_attention_2" if cfg.model.use_flash_attention_2 else None,
+ quantization_config=bnb_config if cfg.train.qlora.use else None,
+ )
+ )
+ if cfg.model.freeze.use:
+ for name, param in model.named_parameters(): # type: ignore # noqa: PGH003
+ if any(layer in name for layer in cfg.model.freeze.trainable_layers):
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+
+ with Path.open(
+ model_save_dir.joinpath(
+ f"{cfg.train.split}/{cfg.train.domain if cfg.train.domain else ''}/input_records.json",
+ ),
+ "r",
+ ) as f:
+ train_input_records = json.load(f)
+ with Path.open(
+ model_save_dir.joinpath(
+ f"{cfg.eval.split}/{cfg.eval.domain if cfg.eval.domain else ''}/input_records.json",
+ ),
+ "r",
+ ) as f:
+ eval_input_records = json.load(f)
+
+ train_input_texts = [{"text": record["text"]} for record in train_input_records]
+ eval_input_texts = [{"text": record["text"]} for record in eval_input_records]
+
+ train(
+ cfg,
+ model,
+ tokenizer,
+ train_input_texts,
+ eval_input_texts,
+ model_save_dir,
+ )
+
+
+def train( # noqa: PLR0913
+ cfg,
+ model,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+ train_input_texts: list[dict[str, Any]],
+ eval_input_texts: list[dict[str, Any]],
+ model_save_dir: Path,
+) -> None:
+ peft_config = LoraConfig(
+ r=cfg.train.qlora.r,
+ lora_alpha=cfg.train.qlora.alpha,
+ lora_dropout=cfg.train.qlora.dropout,
+ bias=cfg.train.qlora.bias,
+ task_type="CAUSAL_LM",
+ target_modules=OmegaConf.to_container(cfg.train.qlora.target_modules, resolve=True), # type: ignore # noqa: PGH003
+ )
+
+ training_args = TrainingArguments(
+ output_dir=str(model_save_dir),
+ num_train_epochs=cfg.train.num_epochs,
+ learning_rate=cfg.train.learning_rate,
+ per_device_train_batch_size=cfg.train.batch_size_per_device,
+ gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+ gradient_checkpointing=cfg.train.gradient_checkpointing,
+ gradient_checkpointing_kwargs={"use_reentrant": False},
+ max_grad_norm=cfg.train.max_grad_norm,
+ optim=cfg.train.optim,
+ weight_decay=cfg.train.weight_decay,
+ lr_scheduler_type=cfg.train.scheduler,
+ warmup_steps=cfg.train.warmup_steps,
+ warmup_ratio=cfg.train.warmup_ratio,
+ save_strategy="steps",
+ save_steps=100,
+ eval_strategy="steps",
+ per_device_eval_batch_size=cfg.eval.batch_size_per_device,
+ eval_accumulation_steps=cfg.eval.gradient_accumulation_steps,
+ eval_steps=100,
+ logging_strategy="steps",
+ logging_steps=10,
+ logging_first_step=True,
+ bf16=True,
+ bf16_full_eval=True,
+ group_by_length=True,
+ prediction_loss_only=True,
+ metric_for_best_model="eval_loss",
+ load_best_model_at_end=True,
+ torch_compile=True,
+ report_to="wandb",
+ ) # type: ignore # noqa: PGH003
+
+ trainer = SFTTrainer(
+ model=model,
+ tokenizer=tokenizer,
+ args=training_args, # type: ignore # noqa: PGH003
+ train_dataset=Dataset.from_list(train_input_texts),
+ eval_dataset=Dataset.from_list(eval_input_texts),
+ max_seq_length=model.config.max_position_embeddings,
+ dataset_text_field="text",
+ peft_config=peft_config if cfg.train.qlora.use else None,
+ )
+
+ wandb.init(project=cfg.wandb.project)
+
+ trainer.train() # type: ignore # noqa: PGH003
+
+ trainer.save_model(str(model_save_dir))
+ tokenizer.save_pretrained(model_save_dir)
+ trainer.state.save_to_json(str(model_save_dir.joinpath("trainer_state.json")))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/webnavix.code-workspace b/webnavix.code-workspace
new file mode 100644
index 0000000..a26ff12
--- /dev/null
+++ b/webnavix.code-workspace
@@ -0,0 +1,8 @@
+{
+ "folders": [
+ {
+ "name": "webnavix",
+ "path": "."
+ }
+ ]
+}