From 5fef94900d3ae882f195964fe310d0878bfe4583 Mon Sep 17 00:00:00 2001 From: shio <85730998+dino3616@users.noreply.github.com> Date: Fri, 30 Aug 2024 22:05:55 +0900 Subject: [PATCH] =?UTF-8?q?initial:=20=F0=9F=8E=89=20first=20commit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/devcontainer.json | 20 ++ .dockerignore | 27 ++ .editorconfig | 12 + .github/workflows/app-test.yaml | 19 + .gitignore | 32 ++ .mypy.ini | 26 ++ .python-version | 1 + .vscode/settings.json | 159 +++++++++ README.md | 1 + docker/Dockerfile.cloud-machine | 27 ++ docker/Dockerfile.development | 26 ++ docker/docker-compose.development.yaml | 13 + lefthook.yaml | 9 + pyproject.toml | 50 +++ requirements-dev.lock | 335 ++++++++++++++++++ requirements.lock | 333 +++++++++++++++++ ruff.toml | 43 +++ src/README.md | 280 +++++++++++++++ src/__init__.py | 0 src/dataset.py | 9 + src/llama/__init__.py | 0 src/llama/accelerate/fsdp_13b.yaml | 28 ++ src/llama/accelerate/fsdp_7b.yaml | 28 ++ src/llama/conf/config.yaml | 65 ++++ src/llama/conf/variant/ft_1.3b.yaml | 11 + src/llama/conf/variant/ft_13b.yaml | 18 + src/llama/conf/variant/ft_2.7b.yaml | 11 + src/llama/conf/variant/ft_7b.yaml | 18 + .../conf/variant/ft_llama3_8b_instruct.yaml | 18 + src/llama/eval.py | 155 ++++++++ src/llama/processing.py | 324 +++++++++++++++++ src/llama/train.py | 208 +++++++++++ webnavix.code-workspace | 8 + 33 files changed, 2314 insertions(+) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .dockerignore create mode 100644 .editorconfig create mode 100644 .github/workflows/app-test.yaml create mode 100644 .gitignore create mode 100644 .mypy.ini create mode 100644 .python-version create mode 100644 .vscode/settings.json create mode 100644 README.md create mode 100644 docker/Dockerfile.cloud-machine create mode 100644 docker/Dockerfile.development create mode 100644 docker/docker-compose.development.yaml create mode 100644 lefthook.yaml create mode 100644 pyproject.toml create mode 100644 requirements-dev.lock create mode 100644 requirements.lock create mode 100644 ruff.toml create mode 100644 src/README.md create mode 100644 src/__init__.py create mode 100644 src/dataset.py create mode 100644 src/llama/__init__.py create mode 100644 src/llama/accelerate/fsdp_13b.yaml create mode 100644 src/llama/accelerate/fsdp_7b.yaml create mode 100644 src/llama/conf/config.yaml create mode 100644 src/llama/conf/variant/ft_1.3b.yaml create mode 100644 src/llama/conf/variant/ft_13b.yaml create mode 100644 src/llama/conf/variant/ft_2.7b.yaml create mode 100644 src/llama/conf/variant/ft_7b.yaml create mode 100644 src/llama/conf/variant/ft_llama3_8b_instruct.yaml create mode 100644 src/llama/eval.py create mode 100644 src/llama/processing.py create mode 100644 src/llama/train.py create mode 100644 webnavix.code-workspace diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..27baf19 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,20 @@ +{ + "name": "webnavix", + "workspaceFolder": "/workspaces/webnavix/", + "dockerComposeFile": ["../docker/docker-compose.development.yaml"], + "service": "app", + "customizations": { + "vscode": { + "extensions": [ + "adam-bender.commit-message-editor", + "charliermarsh.ruff", + "eamodio.gitlens", + "EditorConfig.EditorConfig", + "esbenp.prettier-vscode", + "ms-python.python", + "tamasfe.even-better-toml", + "VisualStudioExptTeam.vscodeintellicode" + ] + } + } +} diff --git 
a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..67ddee0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,27 @@ +# cache +**/__pycache__/ +**/.mypy_cache/ +**/.ruff_cache/ +**/*.egg-info/ + +# dataset +**/wl_data/ + +# debug +**/*log* + +# deliverable +**/build/ +**/dist/ +**/out/ + +# dependency +**/.venv/ + +# env file +**/.env* +!**/.env.example + +# misc +**/.DS_Store +**/*.pem diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..4a7ea30 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*] +indent_style = space +indent_size = 2 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.md] +trim_trailing_whitespace = false diff --git a/.github/workflows/app-test.yaml b/.github/workflows/app-test.yaml new file mode 100644 index 0000000..96dc403 --- /dev/null +++ b/.github/workflows/app-test.yaml @@ -0,0 +1,19 @@ +name: app test + +on: push + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: setup rye + uses: eifinger/setup-rye@v4 + + - name: install dependencies + run: rye sync + + - name: check + run: rye check diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea5032b --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# cache +__pycache__/ +.mypy_cache/ +.ruff_cache/ +*.egg-info/ + +# dataset +wl_data/ + +# debug +*log* +wandb/ + +# deliverable +build/ +checkpoints/ +dist/ +out/ +results/ + +# dependency +.venv/ + +# env file +.env* +!.env.example + +# misc +.DS_Store +*.pem diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000..a929c95 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,26 @@ +[mypy] +allow_redefinition = True +allow_untyped_globals = False +check_untyped_defs = True +color_output = True +disallow_incomplete_defs = True +disallow_untyped_calls = False +disallow_untyped_decorators = False +disallow_untyped_defs = True +error_summary = True +ignore_missing_imports = True +implicit_reexport = True +namespace_packages = True +no_implicit_optional = True +pretty = True +show_column_numbers = True +show_error_codes = True +show_error_context = True +show_traceback = True +strict = True +warn_no_return = True +warn_redundant_casts = True +warn_return_any = True +warn_unreachable = True +warn_unused_configs = True +warn_unused_ignores = False diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..871f80a --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.3 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..13d3842 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,159 @@ +{ + "commit-message-editor.tokens": [ + { + "label": "Type", + "name": "type", + "type": "enum", + "description": "Type of changes.", + "combobox": true, + "options": [ + { + "label": "feat: โœจ", + "value": "feat: โœจ", + "description": "Implementation of new features." + }, + { + "label": "feat: ๐ŸŽˆ", + "value": "feat: ๐ŸŽˆ", + "description": "Repair of existing features." + }, + { + "label": "feat: โšฐ๏ธ", + "value": "feat: โšฐ๏ธ", + "description": "Deletion of features." + }, + { + "label": "fix: ๐Ÿ›", + "value": "fix: ๐Ÿ›", + "description": "Bug fixes." + }, + { + "label": "fix: ๐Ÿš‘๏ธ", + "value": "fix: ๐Ÿš‘๏ธ", + "description": "Critical bug fixes or major changes." 
+ }, + { + "label": "doc: ๐Ÿ“", + "value": "doc: ๐Ÿ“", + "description": "Documentation changes." + }, + { + "label": "typo: ๐Ÿ–‹๏ธ", + "value": "typo: ๐Ÿ–‹๏ธ", + "description": "Typography changes." + }, + { + "label": "style: ๐Ÿ’„", + "value": "style: ๐Ÿ’„", + "description": "Style changes." + }, + { + "label": "refactor: โ™ป๏ธ", + "value": "refactor: โ™ป๏ธ", + "description": "Code formatting or refactoring." + }, + { + "label": "test: ๐Ÿงช", + "value": "test: ๐Ÿงช", + "description": "Test cases changes." + }, + { + "label": "ci: ๐Ÿฆบ", + "value": "ci: ๐Ÿฆบ", + "description": "CI changes." + }, + { + "label": "build: ๐Ÿ“ฆ๏ธ", + "value": "build: ๐Ÿ“ฆ๏ธ", + "description": "Build system or dependency changes." + }, + { + "label": "container: ๐Ÿณ", + "value": "container: ๐Ÿณ", + "description": "The Dockerfile changes." + }, + { + "label": "container: ๐Ÿ™", + "value": "container: ๐Ÿ™", + "description": "The docker-compose changes." + }, + { + "label": "chore: ๐Ÿ”ง", + "value": "chore: ๐Ÿ”ง", + "description": "Configuration changes." + }, + { + "label": "chore: ๐Ÿ”จ", + "value": "chore: ๐Ÿ”จ", + "description": "Development script changes." + }, + { + "label": "chore: ๐Ÿฑ", + "value": "chore: ๐Ÿฑ", + "description": "Assets changes." + }, + { + "label": "revert: โช๏ธ", + "value": "revert: โช๏ธ", + "description": "Reversion of changes." + }, + { + "label": "wip: ๐Ÿšง", + "value": "wip: ๐Ÿšง", + "description": "Changes that will be squashed." + }, + { + "label": "initial: ๐ŸŽ‰", + "value": "initial: ๐ŸŽ‰", + "description": "The first commit." + } + ] + }, + { + "label": "Scope", + "name": "scope", + "type": "text", + "description": "Scope of changes.", + "prefix": " (", + "suffix": ")" + }, + { + "label": "Short Description", + "name": "description", + "type": "text", + "description": "Commit summary.", + "prefix": " " + }, + { + "label": "Body", + "name": "body", + "type": "text", + "description": "Detailed description of commit.", + "maxLines": 10, + "multiline": true, + "lines": 5 + }, + { + "label": "Footer", + "name": "footer", + "description": "Description of disruptive changes or signature.", + "type": "text", + "multiline": true + } + ], + "commit-message-editor.dynamicTemplate": ["{type}{scope}{description}", "", "{body}", "", "{footer}"], + "commit-message-editor.staticTemplate": ["label: emoji (scope) short-description", "", "body", "", "footer"], + "commit-message-editor.view.defaultView": "form", + "editor.defaultFormatter": "esbenp.prettier-vscode", + "files.encoding": "utf8", + "files.eol": "\n", + "python.analysis.typeCheckingMode": "basic", + "python.defaultInterpreterPath": "/opt/rye/shims/python", + "[python]": { + "editor.codeActionsOnSave": { + "source.fixAll.ruff": "explicit", + "source.organizeImports.ruff": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..0b407df --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# WebNavix diff --git a/docker/Dockerfile.cloud-machine b/docker/Dockerfile.cloud-machine new file mode 100644 index 0000000..d55d142 --- /dev/null +++ b/docker/Dockerfile.cloud-machine @@ -0,0 +1,27 @@ +FROM python:3.12 + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +# hadolint ignore=DL3008 +RUN apt-get update \ + && apt-get --no-install-recommends -y install git gnupg2 ca-certificates curl pipx \ + && pipx ensurepath \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists + +RUN curl -sSf https://rye.astral.sh/get | RYE_INSTALL_OPTION="--yes" bash \ + && 
echo "source '$HOME/.rye/env'" >> ~/.bashrc \ + && /root/.rye/shims/rye config --set-bool behavior.global-python=true \ + && /root/.rye/shims/rye config --set-bool behavior.use-uv=true + +RUN RYE_UV_HOME=$(find "$HOME/.rye/uv" -type d -regex '.*/[0-9]+\.[0-9]+\.[0-9]+$') \ + && echo "export PATH=\"$RYE_UV_HOME:\$PATH\"" >> ~/.bashrc + +WORKDIR /workspaces/webnavix/ + +RUN git clone "https://github.com/nitic-nlp-team/webnavix.git" + +# hadolint ignore=SC1091 +RUN "$HOME/.rye/shims/rye" pin "$(cat ./.python-version)" && "$HOME/.rye/shims/rye" sync && source ./.venv/bin/activate + +RUN "$HOME/.rye/shims/rye" run python ./src/dataset.py diff --git a/docker/Dockerfile.development b/docker/Dockerfile.development new file mode 100644 index 0000000..4c54fe1 --- /dev/null +++ b/docker/Dockerfile.development @@ -0,0 +1,26 @@ +FROM python:3.12 + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +# hadolint ignore=DL3008 +RUN apt-get update \ + && apt-get --no-install-recommends -y install git gnupg2 ca-certificates curl pipx \ + && pipx ensurepath \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists + +RUN curl -sSf https://rye.astral.sh/get | RYE_INSTALL_OPTION="--yes" bash \ + && echo "source '$HOME/.rye/env'" >> ~/.bashrc \ + && /root/.rye/shims/rye config --set-bool behavior.global-python=true \ + && /root/.rye/shims/rye config --set-bool behavior.use-uv=true + +RUN RYE_UV_HOME=$(find "$HOME/.rye/uv" -type d -regex '.*/[0-9]+\.[0-9]+\.[0-9]+$') \ + && echo "export PATH=\"$RYE_UV_HOME:\$PATH\"" >> ~/.bashrc + +WORKDIR /workspaces/webnavix/ + +COPY ./.python-version ./pyproject.toml ./requirements* ./ +# hadolint ignore=SC1091 +RUN "$HOME/.rye/shims/rye" pin "$(cat ./.python-version)" && "$HOME/.rye/shims/rye" sync && source ./.venv/bin/activate + +COPY ./ ./ diff --git a/docker/docker-compose.development.yaml b/docker/docker-compose.development.yaml new file mode 100644 index 0000000..4d2eb25 --- /dev/null +++ b/docker/docker-compose.development.yaml @@ -0,0 +1,13 @@ +services: + app: + container_name: webnavix + build: + context: ../ + dockerfile: ./docker/Dockerfile.development + volumes: + - type: bind + source: ../ + target: /workspaces/webnavix/ + environment: + PROJECT_DIR: /workspaces/webnavix/ + tty: true diff --git a/lefthook.yaml b/lefthook.yaml new file mode 100644 index 0000000..9de67fa --- /dev/null +++ b/lefthook.yaml @@ -0,0 +1,9 @@ +pre-commit: + parallel: true + commands: + check-py: + glob: "*.*{py}*" + run: ruff check --fix {staged_files} + format-py: + glob: "*.*{py}*" + run: ruff format --fix {staged_files} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6cfc38c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "webnavix" +version = "0.1.0" +description = "Add your description here" +authors = [ + { name = "shio", email = "85730998+dino3616@users.noreply.github.com" }, +] +dependencies = [ + "accelerate>=0.32.1", + "backoff>=2.2.1", + "bert-score>=0.3.13", + "bitsandbytes>=0.42.0", + "coloredlogs>=15.0.1", + "datasets>=2.20.0", + "huggingface-hub>=0.24.5", + "hydra-core>=1.3.2", + "lxml>=5.2.2", + "ninja>=1.11.1.1", + "numpy>=1.26.4", + "openai>=1.35.15", + "optimum>=1.21.2", + "packaging>=24.1", + "pandas>=2.2.2", + "peft>=0.11.1", + "pillow>=10.4.0", + "python-dotenv>=1.0.1", + "sacrebleu>=2.4.2", + "sentence-transformers>=3.0.1", + "setuptools>=71.1.0", + "tensorboardx>=2.6.2.2", + "tiktoken>=0.7.0", + "torch>=2.3.1", + "tqdm>=4.66.4", + "transformers>=4.42.4", + "trl>=0.9.6", + "wandb>=0.17.7", + "weblinx>=0.3.0", +] 
+readme = "README.md" +requires-python = ">= 3.12" + +[tool.rye] +managed = true +dev-dependencies = ["ruff>=0.5.3", "lefthook>=0.1.2"] + +[tool.rye.scripts] +check = { chain = ["lint", "lint:type", "fmt"] } +"lint" = "ruff check ./ --diff" +"lint:type" = "mypy ./ --explicit-package-bases" +"fmt" = "ruff fmt ./" diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..631bcd8 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,335 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. +accelerate==0.32.1 + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +anyio==4.4.0 + # via httpx + # via openai +attrs==23.2.0 + # via aiohttp +backoff==2.2.1 + # via webnavix +bert-score==0.3.13 + # via webnavix +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via httpcore + # via httpx + # via requests + # via sentry-sdk +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via wandb +colorama==0.4.6 + # via sacrebleu +coloredlogs==15.0.1 + # via optimum + # via webnavix +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +datasets==2.20.0 + # via optimum + # via trl + # via webnavix +dill==0.3.8 + # via datasets + # via multiprocess +distro==1.9.0 + # via openai +docker-pycreds==0.4.0 + # via wandb +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +fonttools==4.53.1 + # via matplotlib +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.5.0 + # via datasets + # via huggingface-hub + # via torch +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via wandb +h11==0.14.0 + # via httpcore +httpcore==1.0.5 + # via httpx +httpx==0.27.0 + # via openai +huggingface-hub==0.24.5 + # via accelerate + # via datasets + # via optimum + # via peft + # via sentence-transformers + # via tokenizers + # via transformers + # via webnavix +humanfriendly==10.0 + # via coloredlogs +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via anyio + # via httpx + # via requests + # via yarl +jinja2==3.1.4 + # via torch +joblib==1.4.2 + # via scikit-learn +kiwisolver==1.4.5 + # via matplotlib +lefthook==0.1.2 +lxml==5.2.2 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +matplotlib==3.9.1 + # via bert-score +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.16 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via webnavix +numpy==1.26.4 + # via accelerate + # via bert-score + # via contourpy + # via datasets + # via matplotlib + # via optimum + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scikit-learn + # via scipy + # via sentence-transformers + # via tensorboardx + # via transformers + # via trl + # via webnavix +omegaconf==2.3.0 + # via hydra-core +openai==1.35.15 + # via webnavix +optimum==1.21.2 + # via webnavix +packaging==24.1 + # via accelerate + # via bert-score + # via datasets + # via huggingface-hub + # via hydra-core + # via matplotlib + # via optimum + # via peft + # via tensorboardx + # via 
transformers + # via webnavix +pandas==2.2.2 + # via bert-score + # via datasets + # via webnavix +peft==0.11.1 + # via webnavix +pillow==10.4.0 + # via matplotlib + # via sentence-transformers + # via webnavix +platformdirs==4.2.2 + # via wandb +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via tensorboardx + # via transformers + # via wandb +psutil==6.0.0 + # via accelerate + # via peft + # via wandb +pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.8.2 + # via openai +pydantic-core==2.20.1 + # via pydantic +pygments==2.18.0 + # via rich +pyparsing==3.1.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +python-dotenv==1.0.1 + # via webnavix +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers + # via wandb +regex==2024.5.15 + # via sacrebleu + # via tiktoken + # via transformers +requests==2.32.3 + # via bert-score + # via datasets + # via huggingface-hub + # via tiktoken + # via transformers + # via wandb +rich==13.7.1 + # via tyro +ruff==0.5.3 +sacrebleu==2.4.2 + # via webnavix +safetensors==0.4.3 + # via accelerate + # via peft + # via transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.14.0 + # via bitsandbytes + # via scikit-learn + # via sentence-transformers +sentence-transformers==3.0.1 + # via webnavix +sentencepiece==0.2.0 + # via transformers +sentry-sdk==2.13.0 + # via wandb +setproctitle==1.3.3 + # via wandb +setuptools==71.1.0 + # via torch + # via wandb + # via webnavix +shtab==1.7.1 + # via tyro +six==1.16.0 + # via docker-pycreds + # via python-dateutil +smmap==5.0.1 + # via gitdb +sniffio==1.3.1 + # via anyio + # via httpx + # via openai +sympy==1.13.1 + # via optimum + # via torch +tabulate==0.9.0 + # via sacrebleu +tensorboardx==2.6.2.2 + # via webnavix +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.7.0 + # via webnavix +tokenizers==0.19.1 + # via transformers +torch==2.4.0 + # via accelerate + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +tqdm==4.66.4 + # via bert-score + # via datasets + # via huggingface-hub + # via openai + # via peft + # via sentence-transformers + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +trl==0.9.6 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via openai + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests + # via sentry-sdk +wandb==0.17.7 + # via webnavix +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..3994e9a --- /dev/null +++ b/requirements.lock @@ -0,0 +1,333 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. 
+accelerate==0.32.1 + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +anyio==4.4.0 + # via httpx + # via openai +attrs==23.2.0 + # via aiohttp +backoff==2.2.1 + # via webnavix +bert-score==0.3.13 + # via webnavix +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via httpcore + # via httpx + # via requests + # via sentry-sdk +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via wandb +colorama==0.4.6 + # via sacrebleu +coloredlogs==15.0.1 + # via optimum + # via webnavix +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +datasets==2.20.0 + # via optimum + # via trl + # via webnavix +dill==0.3.8 + # via datasets + # via multiprocess +distro==1.9.0 + # via openai +docker-pycreds==0.4.0 + # via wandb +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +fonttools==4.53.1 + # via matplotlib +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.5.0 + # via datasets + # via huggingface-hub + # via torch +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via wandb +h11==0.14.0 + # via httpcore +httpcore==1.0.5 + # via httpx +httpx==0.27.0 + # via openai +huggingface-hub==0.24.5 + # via accelerate + # via datasets + # via optimum + # via peft + # via sentence-transformers + # via tokenizers + # via transformers + # via webnavix +humanfriendly==10.0 + # via coloredlogs +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via anyio + # via httpx + # via requests + # via yarl +jinja2==3.1.4 + # via torch +joblib==1.4.2 + # via scikit-learn +kiwisolver==1.4.5 + # via matplotlib +lxml==5.2.2 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +matplotlib==3.9.1 + # via bert-score +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.16 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via webnavix +numpy==1.26.4 + # via accelerate + # via bert-score + # via contourpy + # via datasets + # via matplotlib + # via optimum + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scikit-learn + # via scipy + # via sentence-transformers + # via tensorboardx + # via transformers + # via trl + # via webnavix +omegaconf==2.3.0 + # via hydra-core +openai==1.35.15 + # via webnavix +optimum==1.21.2 + # via webnavix +packaging==24.1 + # via accelerate + # via bert-score + # via datasets + # via huggingface-hub + # via hydra-core + # via matplotlib + # via optimum + # via peft + # via tensorboardx + # via transformers + # via webnavix +pandas==2.2.2 + # via bert-score + # via datasets + # via webnavix +peft==0.11.1 + # via webnavix +pillow==10.4.0 + # via matplotlib + # via sentence-transformers + # via webnavix +platformdirs==4.2.2 + # via wandb +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via tensorboardx + # via transformers + # via wandb +psutil==6.0.0 + # via accelerate + # via peft + # via wandb +pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.8.2 + # via openai +pydantic-core==2.20.1 + # via pydantic +pygments==2.18.0 + # via rich +pyparsing==3.1.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +python-dotenv==1.0.1 + # via webnavix +pytz==2024.1 + 
# via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers + # via wandb +regex==2024.5.15 + # via sacrebleu + # via tiktoken + # via transformers +requests==2.32.3 + # via bert-score + # via datasets + # via huggingface-hub + # via tiktoken + # via transformers + # via wandb +rich==13.7.1 + # via tyro +sacrebleu==2.4.2 + # via webnavix +safetensors==0.4.3 + # via accelerate + # via peft + # via transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.14.0 + # via bitsandbytes + # via scikit-learn + # via sentence-transformers +sentence-transformers==3.0.1 + # via webnavix +sentencepiece==0.2.0 + # via transformers +sentry-sdk==2.13.0 + # via wandb +setproctitle==1.3.3 + # via wandb +setuptools==71.1.0 + # via torch + # via wandb + # via webnavix +shtab==1.7.1 + # via tyro +six==1.16.0 + # via docker-pycreds + # via python-dateutil +smmap==5.0.1 + # via gitdb +sniffio==1.3.1 + # via anyio + # via httpx + # via openai +sympy==1.13.1 + # via optimum + # via torch +tabulate==0.9.0 + # via sacrebleu +tensorboardx==2.6.2.2 + # via webnavix +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.7.0 + # via webnavix +tokenizers==0.19.1 + # via transformers +torch==2.4.0 + # via accelerate + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +tqdm==4.66.4 + # via bert-score + # via datasets + # via huggingface-hub + # via openai + # via peft + # via sentence-transformers + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +trl==0.9.6 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via openai + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests + # via sentry-sdk +wandb==0.17.7 + # via webnavix +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..a7429e2 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,43 @@ +target-version = "py312" +indent-width = 4 +line-length = 120 +exclude = [ + "__pycache__", + ".mypy_cache", + ".ruff_cache", + ".venv", + "*.egg-info", + "build", + "checkpoints", + "dist", + "out", + "wl_data", + "*log*", +] + +[format] +indent-style = "space" +line-ending = "auto" +quote-style = "double" +skip-magic-trailing-comma = false + + +[lint] +select = ["ALL"] +fixable = ["ALL"] +ignore = [ + "ANN001", + "EM101", + "ERA001", + "FBT001", + "FBT002", + "RET504", + "TRY002", + "TRY003", +] + +[extend-per-file-ignores] +"__init__.py" = ["D1", "F403"] +"dataset.py" = ["D1"] +"eval.py" = ["D1"] +"train.py" = ["D1"] diff --git a/src/README.md b/src/README.md new file mode 100644 index 0000000..5c5dc6a --- /dev/null +++ b/src/README.md @@ -0,0 +1,280 @@ +The following instructions assume you are running from this directory (you may need to `cd` to this directory). 
+ +### Download Data + +First, you need to download the `splits.json` file containing information about all the splits, as well as the `train.jsonl` candidates selected by `McGill-NLP/MiniLM-L6-DMR`: + +```python +from huggingface_hub import snapshot_download + +# splits.json +snapshot_download( + repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", allow_patterns="splits.json", local_dir="./wl_data/" +) + +# candidates files +snapshot_download( + repo_id="McGill-NLP/WebLINX-full", + repo_type="dataset", + allow_patterns="candidates/*.jsonl", + local_dir="./wl_data/" +) +``` + +Download the full dataset (warning: this will take a while): + +```python +from huggingface_hub import snapshot_download + +snapshot_download(repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", local_dir="./wl_data/") +``` + +The default configs (`llama/conf/config.yaml`) assume that the `train.jsonl` is located at `./wl_data/candidates/train.jsonl`. If you want to change the path, you need to modify the `config.yaml` accordingly. + +### Set `PROJECT_DIR` + +You need to set the `PROJECT_DIR` environment variable to the root directory of the WebLINX project. For example: + +```bash +export PROJECT_DIR=/path/to/the/modeling/directory/ + +# For example, if you are in the modeling directory, you can run: +export PROJECT_DIR=$(pwd) +``` + +### Install Dependencies + +You need to install the dependencies by running the following command: + +```bash +pip install -r requirements.txt +``` + +However, because `flash-attention` requires `torch` to be pre-installed, it has to be installed right after everything else has been installed: +```bash +pip install wheel +# Regular install +pip install "flash-attn>=2.3.0" +# If you have limited RAM, you can try this: +MAX_JOBS=4 pip install "flash-attn>=2.3.0" --no-build-isolation +# If you have issues with nvcc, try this: +FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install "flash-attn>=2.3.0" --no-build-isolation +``` + +### Optional: Symbolic linking to `WebLINX-full` + +If you downloaded the `WebLINX-full` data to a different location (e.g. a different disk) from your `weblinx/modeling` directory, you might consider using a symbolic link to avoid having to change the `config.yaml` files. You should do something like: + +```bash +ln -s /location/of/your/full/data /location/of/project/weblinx/modeling/wl_data +``` + +For example, if your data is located at `/mnt/research/scratch/users/jdoe/WebLINX-full` but your cloned `weblinx` repository is at `~/dev/weblinx`, then you'd run: + +```bash +ln -s /mnt/research/scratch/users/jdoe/WebLINX-full ~/dev/weblinx/modeling/wl_data +``` + +This corresponds to the `data.base_dir` specified in `config.yaml`, which is `"${project_dir}/wl_data/demonstrations/"`. + +### Dense Markup Ranking (DMR) + +#### Train DMR + +You can train the model by running the following command (it will automatically use the hydra config from `conf/`): + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Finetune MiniLM-L6-DMR (Default) +python -m dmr.train + +# Finetune variant gte or bge +python -m dmr.train +variant=gte +python -m dmr.train +variant=bge +``` + +Results will be saved in `./results` and checkpoints in `./checkpoints`. + +#### Inference for DMR + +You need to specify which `eval.split` you want to evaluate on.
For example, to evaluate on the `valid` split, you can run the following command: + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# On just one split +python -m dmr.eval eval.split=valid + +# On multiple splits (e.g. test_iid, test_vis) +python -m dmr.eval eval.split=test_iid,test_web,test_geo,test_cat,test_vis + +# Or for bge, gte +python -m dmr.eval +variant=gte eval.split=test_iid,test_web,test_geo,test_cat,test_vis +python -m dmr.eval +variant=bge eval.split=test_iid,test_web,test_geo,test_cat,test_vis +``` + +#### Moving generated DMR results to `wl_data/candidates` + +The `scores.jsonl` and `results.json` files will be saved at the path given by `cfg.eval.result_dir` in `modeling/dmr/conf/config.yml`, which defaults to `${project_dir}/results/${project_name}/${model.name}/${eval.split}` and should resolve to `/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2/train` for the `train` split, `.../valid` for the `valid` split, etc. However, since the next steps assume the candidates live at `wl_data/candidates/<split>.jsonl`, you need to move the files manually. For example, you could run: + +```bash +# Change the following paths to match your setup +orig_dir="/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2" + +# This is the directory where the candidates are stored +new_dir="/path/to/wl_data/candidates" + +# You need to move the train split if you plan to use it for training the action model +mv $orig_dir/train/scores.jsonl $new_dir/train.jsonl + +# You can move valid and test IID splits as well +mv $orig_dir/valid/scores.jsonl $new_dir/valid.jsonl +mv $orig_dir/test_iid/scores.jsonl $new_dir/test_iid.jsonl + +# You can move the other OOD test splits as well, after you have run the evaluation +mv $orig_dir/test_web/scores.jsonl $new_dir/test_web.jsonl +mv $orig_dir/test_geo/scores.jsonl $new_dir/test_geo.jsonl +mv $orig_dir/test_cat/scores.jsonl $new_dir/test_cat.jsonl +mv $orig_dir/test_vis/scores.jsonl $new_dir/test_vis.jsonl +``` + +Alternatively, you can update `config.yml` to save the results in the correct directory by overriding `candidates`: +```yaml +# ... +candidates: + # ... + model: "sentence-transformers/all-MiniLM-L6-v2" + path: ${project_dir}/results/${project_name}/${model.name}/${eval.split} +``` + +### Action Model + +#### Train LLaMA + +You can train the model by running the following command (it will automatically use the hydra config from `conf/`): + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Finetune 1.3b variant +python -m llama.train +variant="ft_1.3b" + +# Finetune 2.7b variant +python -m llama.train +variant="ft_2.7b" + +# For 7b, you will need to use fsdp in accelerate to train on 4 GPUs with 48GB VRAM +export CUDA_VISIBLE_DEVICES="0,1,2,3" +accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_7b.yaml -m llama.train +variant="ft_7b" + +# For LLaMA-3-8b-Instruct, you will need to use fsdp in accelerate to train on 4 GPUs with 48GB VRAM +export CUDA_VISIBLE_DEVICES="4,5,6,7" +accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_7b.yaml -m llama.train +variant="ft_llama3_8b_instruct" + +# For 13b, you need 6 GPUs with 48GB VRAM +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" +accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_13b.yaml -m llama.train +variant="ft_13b" +``` + +Results will be saved in `./results` and checkpoints in `./checkpoints`.
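Before committing several GPUs to a long run, it can help to preview what a `+variant=...` override actually resolves to. The snippet below is a minimal sketch (not part of the repository) using Hydra's compose API; it assumes it is run from the `src/` directory so that the relative `llama/conf` path resolves, and `ft_1.3b` is just an example variant:

```python
# Minimal sketch: preview the merged Hydra config for a variant without training.
# Assumes it is executed from the src/ directory (so "llama/conf" resolves) and
# that the variant name matches a file under src/llama/conf/variant/.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="llama/conf"):
    # Mirrors the CLI override `python -m llama.train +variant="ft_1.3b"`.
    cfg = compose(config_name="config", overrides=["+variant=ft_1.3b"])

# Interpolations such as ${oc.env:PROJECT_DIR} stay unresolved here, so this
# works even before PROJECT_DIR / HF_TOKEN / WANDB_* are exported.
print(OmegaConf.to_yaml(cfg))
print(cfg.model.name)  # princeton-nlp/Sheared-LLaMA-1.3B for this variant
```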
+ +#### Run LLaMA on Evaluation Splits + +You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `valid` split, you can run the following command: + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# On just one split +python -m llama.eval +variant="ft_1.3b" eval.split=valid + +# On multiple splits (e.g. test_iid, test_vis) +python -m llama.eval -m +variant="ft_2.7b" eval.split=test_iid,test_web,test_geo,test_cat,test_vis + +# Evaluating llama-3-8b-instruct on all splits +python -m llama.eval -m +variant="ft_llama3_8b_instruct" eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` + +### Evaluation + +To run the evaluation metrics, you can use the following command (from this directory): + +```bash +python -m weblinx.eval -d results -b ./wl_data/demonstrations +``` + +In this case, `-b` is the base directory for the demonstrations, and `-d` is the directory containing the results (generated above by the `llama.eval` script). This will automatically run the evaluation metrics and save the results to `results/aggregated_scores.json`. If you are only interested in the overall score for a split (e.g. `valid`), you can look for the following entry in the aggregated scores file (as an example): + +```json +// ... + { + "split": "valid", + "intent": "overall", + "metric": "overall", + "model_name": "princeton-nlp/Sheared-LLaMA-1.3B", + "project_name": "llama_ft", + "score": 0.21667765869744438, + "unconditional_score": 0.15307513104251605 + }, +// ... +``` + +Behind the scenes, this uses the `weblinx.eval.auto_eval_and_save` function to run the evaluation metrics. If you want more control, you can call that function directly; for an example, check out `weblinx/eval/__main__.py`. + +Note that it might be slow the first time you run it, because it reads a lot of demonstrations and loads millions of files. However, a demo-level cache is automatically created (see `./.cache/demonstrations`), so the next time you run it, it should be much faster.
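If you want that overall number programmatically (for example, to log it alongside a training run), a short script can filter the aggregated file for the `overall` entry of a given split. This is a minimal sketch (not part of the repository), assuming `results/aggregated_scores.json` is a JSON list of records shaped like the example entry above:

```python
# Minimal sketch: pull the overall score for one split out of the aggregated
# results file written by `python -m weblinx.eval`. The path and record layout
# are assumed to match the example entry shown above.
import json
from pathlib import Path

split = "valid"
entries = json.loads(Path("results/aggregated_scores.json").read_text())

for entry in entries:
    if entry.get("split") == split and entry.get("intent") == "overall" and entry.get("metric") == "overall":
        print(
            f"{entry['model_name']} ({entry['project_name']}): "
            f"score={entry['score']:.4f}, unconditional={entry['unconditional_score']:.4f}"
        )
```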
+ +### More models + +#### Flan-T5 + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m flan.train +python -m flan.eval -m eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + + +# Large +python -m modeling.flan.train model.name=google/flan-t5-large +python -m modeling.flan.eval -m model.name=google/flan-t5-large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# XL +python -m modeling.flan.train +variant=ft_xl +python -m modeling.flan.eval +variant=ft_xl eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` + +#### MindAct + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m modeling.flan.train +variant=ft_mindact model.size=base +python -m modeling.flan.eval -m +variant=ft_mindact model.size=base eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# Large +python -m modeling.flan.train +variant=ft_mindact model.size=large +python -m modeling.flan.eval -m +variant=ft_mindact model.size=large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# XL +python -m modeling.flan.train +variant=ft_mindact_xl +python -m modeling.flan.eval -m +variant=ft_mindact_xl eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` + + +#### Pix2Act + +First, you will need to download the TTF file for the Arial font (aka `Arial.TTF`) and place it at `${project_dir}/modeling/fonts/Arial.TTF`. On Windows, you can find it at `C:\windows\fonts\`. On Linux, you can find alternative fonts at `/usr/share/fonts/truetype/`. + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m modeling.pix2act.train +python -m modeling.pix2act.eval eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# Large +python -m modeling.pix2act.train +variant=ft_large +python -m modeling.pix2act.eval +variant=ft_large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..3ca9861 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,9 @@ +from huggingface_hub import snapshot_download + +snapshot_download( + repo_id="McGill-NLP/WebLINX-full", + repo_type="dataset", + local_dir="./wl_data", + ignore_patterns=["**/bboxes/", "**/pages/", "**/screenshots/", "**/video.mp4"], + resume_download=True, +) diff --git a/src/llama/__init__.py b/src/llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/llama/accelerate/fsdp_13b.yaml b/src/llama/accelerate/fsdp_13b.yaml new file mode 100644 index 0000000..afbbc1d --- /dev/null +++ b/src/llama/accelerate/fsdp_13b.yaml @@ -0,0 +1,28 @@ +# Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: 1 + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + # Set fsdp_use_orig_params=true if using peft: + # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1
+num_processes: 6 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/src/llama/accelerate/fsdp_7b.yaml b/src/llama/accelerate/fsdp_7b.yaml new file mode 100644 index 0000000..488cef9 --- /dev/null +++ b/src/llama/accelerate/fsdp_7b.yaml @@ -0,0 +1,28 @@ +# Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: 1 + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + # Set fsdp_use_orig_params=true if using peft: + # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/src/llama/conf/config.yaml b/src/llama/conf/config.yaml new file mode 100644 index 0000000..816163b --- /dev/null +++ b/src/llama/conf/config.yaml @@ -0,0 +1,65 @@ +project_dir: ${oc.env:PROJECT_DIR} +seed: 123 +project_name: llama + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +train: + split: train + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 12 + gradient_accumulation_steps: 2 + dataloader_num_workers: 8 + gradient_checkpointing: True + use_accelerator_device_map: False + use_auto_device_map: True + weight_decay: 0 + warmup_steps: 0 + warmup_ratio: 0 + scheduler: linear + optim: adamw_torch + +eval: + split: valid + batch_size_per_device: 2 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: False # If True, load from model.save_dir instead of model.name + +model: + name: meta-llama/Llama-2-7b-chat-hf + tokenizer: ${model.name} + template_tokenizer: ${model.tokenizer} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + +candidates: + k: 10 + model: McGill-NLP/MiniLM-L6-dmr # unused but potentially useful + project_name: dmr # unused but potentially useful + split: ${eval.split} + train_path: ${project_dir}/wl_data/candidates/train.jsonl + path: ${project_dir}/wl_data/candidates/${candidates.split}.jsonl + +huggingface: + token: ${oc.env:HF_TOKEN} + +wandb: + project: ${oc.env:WANDB_PROJECT} + key: ${oc.env:WANDB_API_KEY} + +hydra: + run: + dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + # Use the same for sweep's subdir + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO diff --git a/src/llama/conf/variant/ft_1.3b.yaml b/src/llama/conf/variant/ft_1.3b.yaml new file mode 100644 index 0000000..8fbd1ec --- /dev/null +++ b/src/llama/conf/variant/ft_1.3b.yaml @@ -0,0 +1,11 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 1 gpu with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: princeton-nlp/Sheared-LLaMA-1.3B + +eval: + batch_size_per_device: 8 + 
load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_13b.yaml b/src/llama/conf/variant/ft_13b.yaml new file mode 100644 index 0000000..84ccf17 --- /dev/null +++ b/src/llama/conf/variant/ft_13b.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 6 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Llama-2-13b-chat-hf + +train: + # 6 (# gpus) * 3 (accum steps) * 1 (bsize) = 18 (batch size) + batch_size_per_device: 1 + gradient_accumulation_steps: 3 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 4 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_2.7b.yaml b/src/llama/conf/variant/ft_2.7b.yaml new file mode 100644 index 0000000..45829fc --- /dev/null +++ b/src/llama/conf/variant/ft_2.7b.yaml @@ -0,0 +1,11 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 1 gpu with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: princeton-nlp/Sheared-LLaMA-2.7B + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_7b.yaml b/src/llama/conf/variant/ft_7b.yaml new file mode 100644 index 0000000..d29509c --- /dev/null +++ b/src/llama/conf/variant/ft_7b.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 4 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Llama-2-7b-chat-hf + +train: + # 4 (# gpus) * 4 (accum steps) * 1 (bsize) = 16 (batch size) + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_llama3_8b_instruct.yaml b/src/llama/conf/variant/ft_llama3_8b_instruct.yaml new file mode 100644 index 0000000..9703619 --- /dev/null +++ b/src/llama/conf/variant/ft_llama3_8b_instruct.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft # TODO: Change this to your project name + +model: + # This is meant to be run on 4 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Meta-Llama-3-8B-Instruct + +train: + # 4 (# gpus) * 4 (accum steps) * 1 (bsize) = 16 (batch size) + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/eval.py b/src/llama/eval.py new file mode 100644 index 0000000..e08ac77 --- /dev/null +++ b/src/llama/eval.py @@ -0,0 +1,155 @@ +import json +import logging +from functools import partial +from pathlib import Path +from typing import Any + +import hydra +import torch +from omegaconf import OmegaConf +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + pipeline, +) +from transformers.pipelines.pt_utils import KeyDataset +from weblinx import Demonstration +from weblinx.processing import load_candidate_elements +from weblinx.processing.prompt import ( + build_input_records_from_selected_turns, + select_turns_and_candidates_for_prompts, +) +from weblinx.utils import load_demo_names_in_split +from weblinx.utils.hydra import save_path_to_hydra_logs + +from 
.processing import ( + build_formatter_for_multichoice, + build_prompt_records_for_llama_truncated, + insert_formatted_chat_into_records, +) + + +@hydra.main(version_base=None, config_path="conf", config_name="config") +def main(cfg) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + result_dir = Path(cfg.eval.result_dir).expanduser() + result_dir.mkdir(parents=True, exist_ok=True) + logger.info(OmegaConf.to_yaml(cfg)) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.model.tokenizer, + add_eos_token=True, + padding_side="right", + trust_remote_code=True, + ) + tokenizer.pad_token = tokenizer.unk_token + + model_kwargs: dict[str, Any] = { + "device_map": "auto", + "torch_dtype": torch.bfloat16, + "rope_scaling": {"type": "dynamic", "factor": 2.0} if cfg.model.use_rope else None, + "attn_implementation": "flash_attention_2" if cfg.model.use_flash_attention_2 else None, + } + load_model_name = str(model_save_dir) if cfg.eval.get("load_from_save_dir", False) is True else cfg.model.name + model = AutoModelForCausalLM.from_pretrained(load_model_name, **model_kwargs) + + input_records = build_input_records_texts(cfg, tokenizer) + + evaluate(cfg, model, tokenizer, input_records, result_dir) + + +def build_input_records_texts(cfg, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> list[dict[str, Any]]: + split_path = Path(cfg.data.split_path).expanduser() + + demo_names: list[str] = load_demo_names_in_split(split_path, split=cfg.eval.split) + demos = [Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] + candidates = load_candidate_elements(path=cfg.candidates.path) + + format_intent = build_formatter_for_multichoice() + build_prompt_records_fn = partial( + build_prompt_records_for_llama_truncated, + format_intent=format_intent, + tokenizer=tokenizer, # type: ignore # noqa: PGH003 + ) + + selected_turns: list[dict[str, Any]] = select_turns_and_candidates_for_prompts( + demos=demos, + candidates=candidates, + num_candidates=cfg.candidates.k, + ) + + input_records = build_input_records_from_selected_turns( + selected_turns=selected_turns, + format_intent=format_intent, + build_prompt_records_fn=build_prompt_records_fn, + format_prompt_records_fn=None, + ) + + template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) + input_records = insert_formatted_chat_into_records( + input_records, + template_tokenizer, + include_output_target=False, + ) + + return input_records + + +def evaluate( + cfg, + model, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + input_records: list[dict[str, Any]], + result_dir: Path, +) -> None: + dset = KeyDataset(input_records, key="text") # type: ignore # noqa: PGH003 + pipe = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + ) + pipe_kwargs = { + "batch_size": cfg.eval.batch_size_per_device, + "pad_token_id": tokenizer.unk_token_id, + "max_new_tokens": cfg.model.max_out_len, + "return_full_text": False, + } + + results = [] + with torch.amp.autocast("cuda", dtype=torch.bfloat16): # type: ignore # noqa: PGH003 + pbar = tqdm( + pipe(dset, **pipe_kwargs), + desc="Generating outputs", + total=len(dset), + ) + for i, out in enumerate(pbar): + rec = input_records[i] + generated_text = out[0]["generated_text"] + result = { + "demo_name": rec["demo_name"], + "turn_index": rec["turn_index"], + "prompt": 
rec["prompt"], + "text": rec["text"], + "output_predicted": generated_text, + "output_target": rec["output_target"], + "output_target_dict": rec["output_target_dict"], + } + + results.append(result) + + with Path.open(result_dir.joinpath("results.json"), "w") as f: + json.dump(results, f, indent=2) + + save_path_to_hydra_logs(save_dir=result_dir) + + +if __name__ == "__main__": + main() diff --git a/src/llama/processing.py b/src/llama/processing.py new file mode 100644 index 0000000..af698a4 --- /dev/null +++ b/src/llama/processing.py @@ -0,0 +1,324 @@ +"""The processing module contains functions to process the data for the LLAMA task.""" + +from collections.abc import Callable +from copy import deepcopy +from functools import partial + +import lxml.html +import weblinx.utils.format as wlf +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from weblinx import Replay, Turn +from weblinx.processing.dom import clean_and_prune_tree +from weblinx.processing.prompt import ( + find_turns_with_instructor_chat, + format_candidates, + format_utterances, + format_utterances_truncated, + get_speaker, + multi_attempt_format_prev_turns_truncated, +) +from weblinx.processing.truncation import ( + multi_attempt_truncate_cands_turn, + multi_attempt_truncate_dom_tree, +) + + +def build_formatter_for_multichoice() -> Callable: + """Build a formatter for the multichoice task.""" + format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) + format_text_input = partial( + wlf.format_text_input, + formatters=( + partial(wlf.format_arg_item, name="text", max_length=200), + wlf.format_uid, + ), + ) + format_change = partial( + wlf.format_change, + formatters=( + partial(wlf.format_arg_item, name="value", max_length=200), + wlf.format_uid, + ), + ) + format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) + format_load = partial( + wlf.format_load, + include_transition=False, + include_timestamp=False, + max_length=200, + ) + format_scroll = partial(wlf.format_scroll, include_timestamp=False) + + format_say = partial(wlf.format_say, include_timestamp=False) + + format_intent_auto = partial( + wlf.format_intent_automatically, + format_change=format_change, + format_click=format_click, + format_load=format_load, + format_say=format_say, + format_scroll=format_scroll, + format_submit=format_submit, + format_text_input=format_text_input, + ) + + return format_intent_auto + + +def get_system_prompt_template_for_llama_mc_concise() -> str: + """Return a system prompt template for the LLAMA task.""" + sys_prompt_template = ( + "You are an AI assistant with a deep understanding of HTML " + "and you must predict actions based on a user request, which will be executed. " + "Use one of the following, replacing [] with an appropriate value: " + "change(value=[str], uid=[str]) ; " + "click(uid=[str]) ; " + "load(url=[str]) ; " + 'say(speaker="navigator", utterance=[str]) ; ' + "scroll(x=[int], y=[int]) ; " + "submit(uid=[str]) ;" + "text_input(text=[str], uid=[str]) ;\n" + "The user's first and last {num_utterances} utterances are: " + "{utterance_context} ;\n" + "Viewport size: {height}h x {width}w ;\n" + "Only the last {num_prev_turns} turns are provided." 
+ ) + + return sys_prompt_template + + +def get_candidate_prompt_template_for_llama() -> str: + """Return a candidate prompt template for the LLAMA task.""" + return "Here are the top candidates for this turn: {candidate_str}\n" + + +def get_final_user_message() -> str: + """Return a final user message for the LLAMA task.""" + return ( + "Please select the best action using the correct format, " + "do not provide any other information or explanation." + ) + + +def merge_prev_turns(prev_turns_text_list: list[str], final_user_message: str) -> list[dict[str, str]]: + """Merge previous turns into a single turn for the LLAMA task.""" + prev_turns_merged: list[dict[str, str]] = [] + + # Merge turns from the same role + for i, turn_text in enumerate(prev_turns_text_list): + role = get_speaker( + turn_text, + instructor_name="user", + navigator_name="assistant", + default_name="unknown", + ) + + if i > 0 and prev_turns_merged[-1]["role"] == role: + prev_turns_merged[-1]["content"] += " " + turn_text + else: + prev_turns_merged.append({"role": role, "content": turn_text}) + + if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user": + prev_turns_merged[-1]["content"] += " " + final_user_message + else: + prev_turns_merged.append({"role": "user", "content": final_user_message}) + + return prev_turns_merged + + +def build_prompt_records_for_llama_truncated( # noqa: PLR0913 + replay: Replay, + turn: Turn, + format_intent, + tokenizer: PreTrainedTokenizer, + cands_turn=None, + num_utterances: int = 5, + num_prev_turns: int = 5, + system_prompt_template=None, + candidate_prompt_template=None, + final_user_message=None, + include_html=True, + format_candidates_fn=partial( # noqa: B008 + format_candidates, + max_char_len=None, # type: ignore # noqa: PGH003 + use_uid_as_rank=True, + ), + merge_prev_turns_fn=merge_prev_turns, + format_output_dict_fn: Callable = partial( # noqa: B008 + wlf.format_output_dictionary, + function_key="intent", + ), + max_html_tokens: int = 700, + max_utterance_tokens: int = 40 * 5, + max_prev_turns_tokens: int = 50 * 5, + max_candidates_tokens: int = 65 * 10, + add_unused_len_to_cands: bool = True, + allow_iterative_reduction: bool = False, + use_tokenizer_template: bool = False, + template_tokenizer=None, + parser=None, +) -> list[dict[str, str]]: + """Parameters + + ---------- + ... + allow_iterative_reduction : bool + This arg is only relevant when truncate_at_center is used behind the scene (e.g. for + multi_attempt_format_prev_turns_truncated or multi_attempt_truncate_dom_tree). If True, + then we will allow the iterative reduction to continue until the max_tokens is reached. + This is useful when the tokenizer output does not necessarily decrease when we remove + tokens from the input. For example, if we remove a token that is part of a word, but + the updated text is retokenized to the same number of tokens, then we will continue + to remove tokens until we reach the max_tokens limit. 
+ """ + if system_prompt_template is None: + system_prompt_template = get_system_prompt_template_for_llama_mc_concise() + + if candidate_prompt_template is None: + candidate_prompt_template = get_candidate_prompt_template_for_llama() + + if final_user_message is None: + final_user_message = get_final_user_message() + + instructor_chat_turns = find_turns_with_instructor_chat( + replay, + turn, + num_prev_turns=num_prev_turns, + ) + utterance_context = format_utterances_truncated( + instructor_chat_turns, + tokenizer=tokenizer, + max_tokens=max_utterance_tokens, + num_utterances=num_utterances, + format_utterances_fn=format_utterances, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_text_list = multi_attempt_format_prev_turns_truncated( + replay=replay, + turn=turn, + format_intent=partial(format_intent, return_as=dict), + tokenizer=tokenizer, + num_prev_turns=num_prev_turns, + # turn_sep=None, # output list + max_tokens=max_prev_turns_tokens, + max_attempts=5, + format_output_dict_fn=format_output_dict_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_merged = merge_prev_turns_fn( + prev_turns_text_list=prev_turns_text_list, + final_user_message=final_user_message, + ) + + sys_prompt = system_prompt_template.format( + num_utterances=num_utterances - 1, # 1 less since we add the first utterance + utterance_context=utterance_context, + height=turn.viewport_height, + width=turn.viewport_width, + num_prev_turns=num_prev_turns, + ) + + if include_html and turn.html not in ["", None] and cands_turn is not None: + dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser) + dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn) + trunc = multi_attempt_truncate_dom_tree( + dom_tree=dom_tree_pruned, + tokenizer=tokenizer, + max_tokens=max_html_tokens, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + html = trunc["tree_repr"] + sys_prompt = html + sys_prompt + else: + html = "" + + if cands_turn is not None: + if add_unused_len_to_cands: + # Add the unused length to the candidates + num_html_tokens = len(tokenizer.tokenize(html)) + num_utter_tokens = len(tokenizer.tokenize(utterance_context)) # type: ignore # noqa: PGH003 + if use_tokenizer_template: + if template_tokenizer is None: + msg = "template_tokenizer must be provided when use_tokenizer_template is True." 
+ raise ValueError(msg) + prev_turns_merged_copy = deepcopy(prev_turns_merged) + if prev_turns_merged[0]["role"] == "assistant": + # insert a dummy user turn + prev_turns_merged_copy.insert(0, {"role": "user", "content": ""}) + num_prev_turns_tokens = len( + template_tokenizer.apply_chat_template( + [{"role": "system", "content": ""}, *prev_turns_merged_copy], + tokenize=True, + ), + ) + else: + num_prev_turns_tokens = len( + tokenizer.tokenize(" ".join(prev_turns_text_list)), + ) + remain_html_tokens = max_html_tokens - num_html_tokens + remain_utter_tokens = max_utterance_tokens - num_utter_tokens + remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens + remain_tokens = remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens + # Add the unused length to the max_candidates_tokens + max_candidates_tokens += remain_tokens + + cands_turn_trunc = multi_attempt_truncate_cands_turn( + cands_turn=cands_turn, + tokenizer=tokenizer, + max_tokens=max_candidates_tokens, + format_candidates_fn=format_candidates_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None) # type: ignore # noqa: PGH003 + cand_prompt = candidate_prompt_template.format(candidate_str=cand_str) + sys_prompt += "\n" + cand_prompt + + return [{"role": "system", "content": sys_prompt}, *prev_turns_merged] + + +def __insert_empty_user_content_at_first(prompt: list) -> None: + """Given a list of dictionaries representing the input prompt, insert an empty user message right after the system message. + + The insertion only happens if the message that follows the system message is not already a user message. This is done in place. + """ + if prompt[0]["role"] != "system": + msg = f"First prompt must be a system prompt. Got {prompt[0]['role']} instead." + raise ValueError(msg) + + if prompt[1]["role"] != "user": + prompt.insert(1, {"role": "user", "content": ""}) + + +def insert_formatted_chat_into_records( + records, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + *, + include_output_target: bool = True, +) -> list: + """Given a list of records, return a copy of the records with the rendered chat text added under the "text" key. + + Note that we need a tokenizer's `apply_chat_template` method to be available. 
+ """ + processed_records = deepcopy(records) + for i, record in enumerate(records): + __insert_empty_user_content_at_first(record["prompt"]) + + if include_output_target: + target = [{"role": "assistant", "content": record["output_target"]}] + combined = record["prompt"] + target + else: + combined = record["prompt"] + + text = tokenizer.apply_chat_template( + combined, + tokenize=False, + add_generation_prompt=False, + ) + processed_records[i]["text"] = text + + return processed_records diff --git a/src/llama/train.py b/src/llama/train.py new file mode 100644 index 0000000..6ce4716 --- /dev/null +++ b/src/llama/train.py @@ -0,0 +1,208 @@ +import json +import logging +from functools import partial +from pathlib import Path +from typing import Any + +import datasets +import huggingface_hub +import hydra +import torch +import wandb +from accelerate import Accelerator +from dotenv import load_dotenv +from omegaconf import OmegaConf +from peft import LoraConfig # type: ignore # noqa: PGH003 +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TrainingArguments, +) +from trl import SFTTrainer +from weblinx import Demonstration +from weblinx.processing import load_candidate_elements +from weblinx.processing.prompt import ( + build_input_records_from_selected_turns, + select_turns_and_candidates_for_prompts, +) +from weblinx.utils import load_demo_names_in_split, set_seed +from weblinx.utils.hydra import save_path_to_hydra_logs + +from .processing import ( + build_formatter_for_multichoice, + build_prompt_records_for_llama_truncated, + insert_formatted_chat_into_records, +) + +load_dotenv() + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + huggingface_hub.login(token=cfg.huggingface.token) + wandb.login(key=cfg.wandb.key) + + set_seed(cfg.seed) + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + logger.info(OmegaConf.to_yaml(cfg)) + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.model.tokenizer, + add_eos_token=True, + padding_side="right", + trust_remote_code=True, + ) + tokenizer.pad_token = tokenizer.unk_token + + model_kwargs: dict[str, Any] = { + "device_map": {"": Accelerator().process_index} if cfg.train.use_accelerator_device_map else "auto", + "torch_dtype": torch.bfloat16, + "use_cache": False, + "attn_implementation": "flash_attention_2" if cfg.model.use_flash_attention_2 else None, + "quantization_config": bnb_config, + } + model = AutoModelForCausalLM.from_pretrained(cfg.model.name, **model_kwargs) + + if (model_save_dir.joinpath("input_records_trunc.json")).exists(): + with Path.open(model_save_dir.joinpath("input_records_trunc.json"), "r") as f: + input_records = json.load(f) + input_records_texts = [{"text": record["text"]} for record in input_records] + else: + input_records_texts = build_input_records_texts(cfg, tokenizer) + + train(cfg, model, tokenizer, input_records_texts) + + +def build_input_records_texts(cfg, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> list[dict[str, Any]]: + split_path = Path(cfg.data.split_path).expanduser() + + demo_names: list[str] = load_demo_names_in_split(split_path, 
split=cfg.train.split) + demos = [Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] + candidates = load_candidate_elements(path=cfg.candidates.train_path) + + format_intent = build_formatter_for_multichoice() + build_prompt_records_fn = partial( + build_prompt_records_for_llama_truncated, + format_intent=format_intent, + tokenizer=tokenizer, # type: ignore # noqa: PGH003 + ) + + selected_turns: list[dict[str, Any]] = select_turns_and_candidates_for_prompts( + demos=demos, + candidates=candidates, + num_candidates=cfg.candidates.k, + ) + + input_records = build_input_records_from_selected_turns( + selected_turns=selected_turns, + format_intent=format_intent, + build_prompt_records_fn=build_prompt_records_fn, + format_prompt_records_fn=None, + ) + + template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) + input_records = insert_formatted_chat_into_records( + input_records, + template_tokenizer, + include_output_target=True, + ) + + with Path.open(Path(cfg.model.save_dir).expanduser().joinpath("input_records_trunc.json"), "w") as f: + json.dump(input_records, f, indent=2) + + input_records_texts = [{"text": record["text"]} for record in input_records] + + return input_records_texts + + +def train( + cfg, + model, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + input_records_texts: list[dict[str, Any]], +) -> None: + model_save_dir = Path(cfg.model.save_dir).expanduser() + + peft_config = LoraConfig( + r=256, + lora_alpha=256, + # r=640, + # lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=[ + # "q_proj", + # "k_proj", + # "v_proj", + # "o_proj", + "gate_proj", + "up_proj", + "down_proj", + # "embed_tokens", + # "lm_head", + ], + ) + + training_args = TrainingArguments( + output_dir=str(model_save_dir), + optim=cfg.train.optim, + learning_rate=cfg.train.learning_rate, + num_train_epochs=cfg.train.num_epochs, + per_device_train_batch_size=cfg.train.batch_size_per_device, + gradient_accumulation_steps=cfg.train.gradient_accumulation_steps, + gradient_checkpointing=cfg.train.gradient_checkpointing, + weight_decay=cfg.train.weight_decay, + warmup_steps=cfg.train.warmup_steps, + warmup_ratio=cfg.train.warmup_ratio, + lr_scheduler_type=cfg.train.scheduler, + group_by_length=True, + save_steps=100, + save_strategy="steps", + logging_steps=10, + logging_strategy="steps", + logging_first_step=True, + prediction_loss_only=True, + bf16=True, + bf16_full_eval=True, + torch_compile=True, + report_to="wandb", + ) # type: ignore # noqa: PGH003 + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + args=training_args, # type: ignore # noqa: PGH003 + train_dataset=datasets.Dataset.from_list(input_records_texts), + # NOTE: max_seq_length and dataset_text_field are no longer supported in v1.0.0. + max_seq_length=model.config.max_position_embeddings, + dataset_text_field="text", + peft_config=peft_config, + ) + + wandb.init(project=cfg.wandb.project) + + trainer.train() # type: ignore # noqa: PGH003 + + trainer.save_model(str(model_save_dir)) + tokenizer.save_pretrained(model_save_dir) + trainer.state.save_to_json(str(model_save_dir.joinpath("trainer_state.json"))) + save_path_to_hydra_logs(save_dir=model_save_dir) + + +if __name__ == "__main__": + main() diff --git a/webnavix.code-workspace b/webnavix.code-workspace new file mode 100644 index 0000000..a26ff12 --- /dev/null +++ b/webnavix.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "name": "webnavix", + "path": "." + } + ] +}
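The sketches below illustrate how the prompt-processing helpers introduced in src/llama/processing.py are meant to be used; none of these lines belong to the patched files, the import path src.llama.processing is an assumption based on the repository layout, and every concrete value is made up.

First, how the system prompt template gets its placeholders filled, mirroring what build_prompt_records_for_llama_truncated does internally:

from src.llama.processing import get_system_prompt_template_for_llama_mc_concise

# All values below are illustrative; in the pipeline they come from the turn
# and from the num_utterances / num_prev_turns arguments.
sys_prompt = get_system_prompt_template_for_llama_mc_concise().format(
    num_utterances=4,  # the function passes num_utterances - 1 here
    utterance_context="[00:01] Hi [00:42] Please open the pricing page.",
    height=720,
    width=1280,
    num_prev_turns=5,
)
print(sys_prompt)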
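A minimal sketch of merge_prev_turns, assuming the previous turns have already been rendered to strings by the intent formatter (the three action strings are invented):

from src.llama.processing import get_final_user_message, merge_prev_turns

prev_turns = [
    'say(speaker="instructor", utterance="Open the documentation page.")',
    'load(url="https://example.com/docs")',
    'click(uid="nav-42")',
]
merged = merge_prev_turns(prev_turns, get_final_user_message())
# Consecutive turns attributed to the same speaker are concatenated into a
# single chat message, and the final user instruction is appended at the end,
# so `merged` can be placed directly after the system prompt.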
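The add_unused_len_to_cands branch of build_prompt_records_for_llama_truncated hands whatever the HTML, utterances, and previous turns did not consume over to the candidate list, so the overall prompt budget stays roughly constant. A worked example with the default budgets and made-up usage counts:

# Budgets, matching the defaults in the function signature.
max_html_tokens = 700
max_utterance_tokens = 200   # 40 * 5
max_prev_turns_tokens = 250  # 50 * 5
max_candidates_tokens = 650  # 65 * 10

# Hypothetical sizes of the already-truncated sections.
num_html_tokens = 512
num_utter_tokens = 120
num_prev_turns_tokens = 180

remain_tokens = (
    (max_html_tokens - num_html_tokens)                # 188
    + (max_utterance_tokens - num_utter_tokens)        # 80
    + (max_prev_turns_tokens - num_prev_turns_tokens)  # 70
)
max_candidates_tokens += remain_tokens  # 650 + 338 = 988 tokens for candidates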
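Finally, a sketch of insert_formatted_chat_into_records, mirroring what build_input_records_texts in src/llama/train.py does with cfg.model.template_tokenizer; the record contents and the checkpoint name are placeholders, and any tokenizer that ships a chat template should behave the same way:

from transformers import AutoTokenizer

from src.llama.processing import insert_formatted_chat_into_records

# Placeholder checkpoint; train.py loads cfg.model.template_tokenizer instead.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

records = [
    {
        "prompt": [
            {"role": "system", "content": "You are an AI assistant ..."},
            {"role": "user", "content": "click on the search box"},
        ],
        "output_target": 'click(uid="abc123")',
    },
]

processed = insert_formatted_chat_into_records(records, tokenizer, include_output_target=True)
# processed[0]["text"] now holds the conversation rendered through the
# tokenizer's chat template, ending with the assistant's target action; this is
# the "text" field that the SFTTrainer dataset in train.py is built from.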