diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..27baf19 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,20 @@ +{ + "name": "webnavix", + "workspaceFolder": "/workspaces/webnavix/", + "dockerComposeFile": ["../docker/docker-compose.development.yaml"], + "service": "app", + "customizations": { + "vscode": { + "extensions": [ + "adam-bender.commit-message-editor", + "charliermarsh.ruff", + "eamodio.gitlens", + "EditorConfig.EditorConfig", + "esbenp.prettier-vscode", + "ms-python.python", + "tamasfe.even-better-toml", + "VisualStudioExptTeam.vscodeintellicode" + ] + } + } +} diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..67ddee0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,27 @@ +# cache +**/__pycache__/ +**/.mypy_cache/ +**/.ruff_cache/ +**/*.egg-info/ + +# dataset +**/wl_data/ + +# debug +**/*log* + +# deliverable +**/build/ +**/dist/ +**/out/ + +# dependency +**/.venv/ + +# env file +**/.env* +!**/.env.example + +# misc +**/.DS_Store +**/*.pem diff --git a/.github/workflows/app-test.yaml b/.github/workflows/app-test.yaml new file mode 100644 index 0000000..96dc403 --- /dev/null +++ b/.github/workflows/app-test.yaml @@ -0,0 +1,19 @@ +name: app test + +on: push + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: setup rye + uses: eifinger/setup-rye@v4 + + - name: install dependencies + run: rye sync + + - name: check + run: rye check diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e87399b --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# cache +__pycache__/ +.mypy_cache/ +.ruff_cache/ +*.egg-info/ + +# dataset +wl_data/ + +# debug +*log* + +# deliverable +build/ +checkpoints/ +dist/ +out/ + +# dependency +.venv/ + +# env file +.env* +!.env.example + +# misc +.DS_Store +*.pem diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..871f80a --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.3 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..13d3842 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,159 @@ +{ + "commit-message-editor.tokens": [ + { + "label": "Type", + "name": "type", + "type": "enum", + "description": "Type of changes.", + "combobox": true, + "options": [ + { + "label": "feat: โœจ", + "value": "feat: โœจ", + "description": "Implementation of new features." + }, + { + "label": "feat: ๐ŸŽˆ", + "value": "feat: ๐ŸŽˆ", + "description": "Repair of existing features." + }, + { + "label": "feat: โšฐ๏ธ", + "value": "feat: โšฐ๏ธ", + "description": "Deletion of features." + }, + { + "label": "fix: ๐Ÿ›", + "value": "fix: ๐Ÿ›", + "description": "Bug fixes." + }, + { + "label": "fix: ๐Ÿš‘๏ธ", + "value": "fix: ๐Ÿš‘๏ธ", + "description": "Critical bug fixes or major changes." + }, + { + "label": "doc: ๐Ÿ“", + "value": "doc: ๐Ÿ“", + "description": "Documentation changes." + }, + { + "label": "typo: ๐Ÿ–‹๏ธ", + "value": "typo: ๐Ÿ–‹๏ธ", + "description": "Typography changes." + }, + { + "label": "style: ๐Ÿ’„", + "value": "style: ๐Ÿ’„", + "description": "Style changes." + }, + { + "label": "refactor: โ™ป๏ธ", + "value": "refactor: โ™ป๏ธ", + "description": "Code formatting or refactoring." + }, + { + "label": "test: ๐Ÿงช", + "value": "test: ๐Ÿงช", + "description": "Test cases changes." 
+ }, + { + "label": "ci: ๐Ÿฆบ", + "value": "ci: ๐Ÿฆบ", + "description": "CI changes." + }, + { + "label": "build: ๐Ÿ“ฆ๏ธ", + "value": "build: ๐Ÿ“ฆ๏ธ", + "description": "Build system or dependency changes." + }, + { + "label": "container: ๐Ÿณ", + "value": "container: ๐Ÿณ", + "description": "The Dockerfile changes." + }, + { + "label": "container: ๐Ÿ™", + "value": "container: ๐Ÿ™", + "description": "The docker-compose changes." + }, + { + "label": "chore: ๐Ÿ”ง", + "value": "chore: ๐Ÿ”ง", + "description": "Configuration changes." + }, + { + "label": "chore: ๐Ÿ”จ", + "value": "chore: ๐Ÿ”จ", + "description": "Development script changes." + }, + { + "label": "chore: ๐Ÿฑ", + "value": "chore: ๐Ÿฑ", + "description": "Assets changes." + }, + { + "label": "revert: โช๏ธ", + "value": "revert: โช๏ธ", + "description": "Reversion of changes." + }, + { + "label": "wip: ๐Ÿšง", + "value": "wip: ๐Ÿšง", + "description": "Changes that will be squashed." + }, + { + "label": "initial: ๐ŸŽ‰", + "value": "initial: ๐ŸŽ‰", + "description": "The first commit." + } + ] + }, + { + "label": "Scope", + "name": "scope", + "type": "text", + "description": "Scope of changes.", + "prefix": " (", + "suffix": ")" + }, + { + "label": "Short Description", + "name": "description", + "type": "text", + "description": "Commit summary.", + "prefix": " " + }, + { + "label": "Body", + "name": "body", + "type": "text", + "description": "Detailed description of commit.", + "maxLines": 10, + "multiline": true, + "lines": 5 + }, + { + "label": "Footer", + "name": "footer", + "description": "Description of disruptive changes or signature.", + "type": "text", + "multiline": true + } + ], + "commit-message-editor.dynamicTemplate": ["{type}{scope}{description}", "", "{body}", "", "{footer}"], + "commit-message-editor.staticTemplate": ["label: emoji (scope) short-description", "", "body", "", "footer"], + "commit-message-editor.view.defaultView": "form", + "editor.defaultFormatter": "esbenp.prettier-vscode", + "files.encoding": "utf8", + "files.eol": "\n", + "python.analysis.typeCheckingMode": "basic", + "python.defaultInterpreterPath": "/opt/rye/shims/python", + "[python]": { + "editor.codeActionsOnSave": { + "source.fixAll.ruff": "explicit", + "source.organizeImports.ruff": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..0b407df --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# WebNavix diff --git a/docker/Dockerfile.development b/docker/Dockerfile.development new file mode 100644 index 0000000..54cfaf9 --- /dev/null +++ b/docker/Dockerfile.development @@ -0,0 +1,27 @@ +FROM python:3.12 + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +# hadolint ignore=DL3008 +RUN apt-get update \ + && apt-get --no-install-recommends -y install git gnupg2 ca-certificates curl pipx \ + && pipx ensurepath \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists + +RUN curl -sSf https://rye.astral.sh/get | RYE_INSTALL_OPTION="--yes" bash \ + && echo "source '$HOME/.rye/env'" >> ~/.bashrc \ + && /root/.rye/shims/rye config --set-bool behavior.global-python=true \ + && /root/.rye/shims/rye config --set-bool behavior.use-uv=true + +ENV RYE_HOME="/opt/rye" +ENV PATH="$RYE_HOME/shims:$PATH" + +WORKDIR /workspaces/webnavix/ + +COPY ./.python-version ./pyproject.toml ./requirements* ./ +RUN /root/.rye/shims/rye pin "$(cat ./.python-version)" && /root/.rye/shims/rye sync + +COPY ./ ./ + +RUN if [ ! 
-d ./wl_data/ ]; then /root/.rye/shims/rye run python ./src/dataset.py; fi
diff --git a/docker/docker-compose.development.yaml b/docker/docker-compose.development.yaml
new file mode 100644
index 0000000..b6d0458
--- /dev/null
+++ b/docker/docker-compose.development.yaml
@@ -0,0 +1,13 @@
+services:
+  app:
+    container_name: webnavix
+    build:
+      context: ../
+      dockerfile: ./docker/Dockerfile.development
+    volumes:
+      - type: bind
+        source: ../
+        target: /workspaces/webnavix/
+    environment:
+      WEBLINX_PROJECT_DIR: /workspaces/webnavix/
+    tty: true
diff --git a/lefthook.yaml b/lefthook.yaml
new file mode 100644
index 0000000..9de67fa
--- /dev/null
+++ b/lefthook.yaml
@@ -0,0 +1,9 @@
+pre-commit:
+  parallel: true
+  commands:
+    check-py:
+      glob: "*.py"
+      run: ruff check --fix {staged_files}
+    format-py:
+      glob: "*.py"
+      run: ruff format {staged_files}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..930abd8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,48 @@
+[project]
+name = "webnavix"
+version = "0.1.0"
+description = "Add your description here"
+authors = [
+    { name = "shio", email = "85730998+dino3616@users.noreply.github.com" },
+]
+dependencies = [
+    "accelerate>=0.32.1",
+    "backoff>=2.2.1",
+    "bert-score>=0.3.13",
+    "bitsandbytes>=0.42.0",
+    "coloredlogs>=15.0.1",
+    "datasets>=2.20.0",
+    "huggingface-hub>=0.24.5",
+    "hydra-core>=1.3.2",
+    "lxml>=5.2.2",
+    "ninja>=1.11.1.1",
+    "numpy>=1.26.4",
+    "openai>=1.35.15",
+    "optimum>=1.21.2",
+    "packaging>=24.1",
+    "pandas>=2.2.2",
+    "peft>=0.11.1",
+    "pillow>=10.4.0",
+    "sacrebleu>=2.4.2",
+    "sentence-transformers>=3.0.1",
+    "setuptools>=71.1.0",
+    "tensorboardx>=2.6.2.2",
+    "tiktoken>=0.7.0",
+    "torch>=2.3.1",
+    "tqdm>=4.66.4",
+    "transformers>=4.42.4",
+    "trl>=0.9.6",
+    "weblinx>=0.3.0",
+]
+readme = "README.md"
+requires-python = ">= 3.12"
+
+[tool.rye]
+managed = true
+dev-dependencies = ["ruff>=0.5.3", "lefthook>=0.1.2"]
+
+[tool.rye.scripts]
+check = { chain = ["lint", "lint:type", "fmt"] }
+"lint" = "ruff check ./ --diff"
+"lint:type" = "mypy ./ --explicit-package-bases"
+"fmt" = "ruff format ./"
diff --git a/requirements-dev.lock b/requirements-dev.lock
new file mode 100644
index 0000000..a2a816a
--- /dev/null
+++ b/requirements-dev.lock
@@ -0,0 +1,307 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+#   pre: false
+#   features: []
+#   all-features: false
+#   with-sources: false
+#   generate-hashes: false
+#   universal: false
+
+-e file:.
+accelerate==0.32.1 + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +anyio==4.4.0 + # via httpx + # via openai +attrs==23.2.0 + # via aiohttp +backoff==2.2.1 + # via webnavix +bert-score==0.3.13 + # via webnavix +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via httpcore + # via httpx + # via requests +charset-normalizer==3.3.2 + # via requests +colorama==0.4.6 + # via sacrebleu +coloredlogs==15.0.1 + # via optimum + # via webnavix +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +datasets==2.20.0 + # via optimum + # via trl + # via webnavix +dill==0.3.8 + # via datasets + # via multiprocess +distro==1.9.0 + # via openai +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +fonttools==4.53.1 + # via matplotlib +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.5.0 + # via datasets + # via huggingface-hub + # via torch +h11==0.14.0 + # via httpcore +httpcore==1.0.5 + # via httpx +httpx==0.27.0 + # via openai +huggingface-hub==0.24.5 + # via accelerate + # via datasets + # via optimum + # via peft + # via sentence-transformers + # via tokenizers + # via transformers + # via webnavix +humanfriendly==10.0 + # via coloredlogs +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via anyio + # via httpx + # via requests + # via yarl +jinja2==3.1.4 + # via torch +joblib==1.4.2 + # via scikit-learn +kiwisolver==1.4.5 + # via matplotlib +lefthook==0.1.2 +lxml==5.2.2 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +matplotlib==3.9.1 + # via bert-score +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.16 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via webnavix +numpy==1.26.4 + # via accelerate + # via bert-score + # via contourpy + # via datasets + # via matplotlib + # via optimum + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scikit-learn + # via scipy + # via sentence-transformers + # via tensorboardx + # via transformers + # via trl + # via webnavix +omegaconf==2.3.0 + # via hydra-core +openai==1.35.15 + # via webnavix +optimum==1.21.2 + # via webnavix +packaging==24.1 + # via accelerate + # via bert-score + # via datasets + # via huggingface-hub + # via hydra-core + # via matplotlib + # via optimum + # via peft + # via tensorboardx + # via transformers + # via webnavix +pandas==2.2.2 + # via bert-score + # via datasets + # via webnavix +peft==0.11.1 + # via webnavix +pillow==10.4.0 + # via matplotlib + # via sentence-transformers + # via webnavix +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via tensorboardx + # via transformers +psutil==6.0.0 + # via accelerate + # via peft +pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.8.2 + # via openai +pydantic-core==2.20.1 + # via pydantic +pygments==2.18.0 + # via rich +pyparsing==3.1.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers +regex==2024.5.15 + # via sacrebleu + # via tiktoken + # via transformers +requests==2.32.3 + 
# via bert-score + # via datasets + # via huggingface-hub + # via tiktoken + # via transformers +rich==13.7.1 + # via tyro +ruff==0.5.3 +sacrebleu==2.4.2 + # via webnavix +safetensors==0.4.3 + # via accelerate + # via peft + # via transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.14.0 + # via bitsandbytes + # via scikit-learn + # via sentence-transformers +sentence-transformers==3.0.1 + # via webnavix +sentencepiece==0.2.0 + # via transformers +setuptools==71.1.0 + # via torch + # via webnavix +shtab==1.7.1 + # via tyro +six==1.16.0 + # via python-dateutil +sniffio==1.3.1 + # via anyio + # via httpx + # via openai +sympy==1.13.1 + # via optimum + # via torch +tabulate==0.9.0 + # via sacrebleu +tensorboardx==2.6.2.2 + # via webnavix +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.7.0 + # via webnavix +tokenizers==0.19.1 + # via transformers +torch==2.4.0 + # via accelerate + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +tqdm==4.66.4 + # via bert-score + # via datasets + # via huggingface-hub + # via openai + # via peft + # via sentence-transformers + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +trl==0.9.6 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via openai + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..665bfe3 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,305 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. 
+accelerate==0.32.1 + # via peft + # via trl + # via webnavix +aiohttp==3.9.5 + # via datasets + # via fsspec +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf +anyio==4.4.0 + # via httpx + # via openai +attrs==23.2.0 + # via aiohttp +backoff==2.2.1 + # via webnavix +bert-score==0.3.13 + # via webnavix +bitsandbytes==0.42.0 + # via webnavix +certifi==2024.7.4 + # via httpcore + # via httpx + # via requests +charset-normalizer==3.3.2 + # via requests +colorama==0.4.6 + # via sacrebleu +coloredlogs==15.0.1 + # via optimum + # via webnavix +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +datasets==2.20.0 + # via optimum + # via trl + # via webnavix +dill==0.3.8 + # via datasets + # via multiprocess +distro==1.9.0 + # via openai +docstring-parser==0.16 + # via tyro +filelock==3.15.4 + # via datasets + # via huggingface-hub + # via torch + # via transformers +fonttools==4.53.1 + # via matplotlib +frozenlist==1.4.1 + # via aiohttp + # via aiosignal +fsspec==2024.5.0 + # via datasets + # via huggingface-hub + # via torch +h11==0.14.0 + # via httpcore +httpcore==1.0.5 + # via httpx +httpx==0.27.0 + # via openai +huggingface-hub==0.24.5 + # via accelerate + # via datasets + # via optimum + # via peft + # via sentence-transformers + # via tokenizers + # via transformers + # via webnavix +humanfriendly==10.0 + # via coloredlogs +hydra-core==1.3.2 + # via webnavix +idna==3.7 + # via anyio + # via httpx + # via requests + # via yarl +jinja2==3.1.4 + # via torch +joblib==1.4.2 + # via scikit-learn +kiwisolver==1.4.5 + # via matplotlib +lxml==5.2.2 + # via sacrebleu + # via webnavix +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 +matplotlib==3.9.1 + # via bert-score +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via aiohttp + # via yarl +multiprocess==0.70.16 + # via datasets +networkx==3.3 + # via torch +ninja==1.11.1.1 + # via webnavix +numpy==1.26.4 + # via accelerate + # via bert-score + # via contourpy + # via datasets + # via matplotlib + # via optimum + # via pandas + # via peft + # via pyarrow + # via sacrebleu + # via scikit-learn + # via scipy + # via sentence-transformers + # via tensorboardx + # via transformers + # via trl + # via webnavix +omegaconf==2.3.0 + # via hydra-core +openai==1.35.15 + # via webnavix +optimum==1.21.2 + # via webnavix +packaging==24.1 + # via accelerate + # via bert-score + # via datasets + # via huggingface-hub + # via hydra-core + # via matplotlib + # via optimum + # via peft + # via tensorboardx + # via transformers + # via webnavix +pandas==2.2.2 + # via bert-score + # via datasets + # via webnavix +peft==0.11.1 + # via webnavix +pillow==10.4.0 + # via matplotlib + # via sentence-transformers + # via webnavix +portalocker==2.10.1 + # via sacrebleu +protobuf==5.27.2 + # via tensorboardx + # via transformers +psutil==6.0.0 + # via accelerate + # via peft +pyarrow==17.0.0 + # via datasets +pyarrow-hotfix==0.6 + # via datasets +pydantic==2.8.2 + # via openai +pydantic-core==2.20.1 + # via pydantic +pygments==2.18.0 + # via rich +pyparsing==3.1.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via accelerate + # via datasets + # via huggingface-hub + # via omegaconf + # via peft + # via transformers +regex==2024.5.15 + # via sacrebleu + # via tiktoken + # via transformers +requests==2.32.3 + # via bert-score 
+ # via datasets + # via huggingface-hub + # via tiktoken + # via transformers +rich==13.7.1 + # via tyro +sacrebleu==2.4.2 + # via webnavix +safetensors==0.4.3 + # via accelerate + # via peft + # via transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.14.0 + # via bitsandbytes + # via scikit-learn + # via sentence-transformers +sentence-transformers==3.0.1 + # via webnavix +sentencepiece==0.2.0 + # via transformers +setuptools==71.1.0 + # via torch + # via webnavix +shtab==1.7.1 + # via tyro +six==1.16.0 + # via python-dateutil +sniffio==1.3.1 + # via anyio + # via httpx + # via openai +sympy==1.13.1 + # via optimum + # via torch +tabulate==0.9.0 + # via sacrebleu +tensorboardx==2.6.2.2 + # via webnavix +threadpoolctl==3.5.0 + # via scikit-learn +tiktoken==0.7.0 + # via webnavix +tokenizers==0.19.1 + # via transformers +torch==2.4.0 + # via accelerate + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +tqdm==4.66.4 + # via bert-score + # via datasets + # via huggingface-hub + # via openai + # via peft + # via sentence-transformers + # via transformers + # via weblinx + # via webnavix +transformers==4.42.4 + # via bert-score + # via optimum + # via peft + # via sentence-transformers + # via trl + # via webnavix +trl==0.9.6 + # via webnavix +typing-extensions==4.12.2 + # via huggingface-hub + # via openai + # via pydantic + # via pydantic-core + # via torch + # via tyro +tyro==0.8.5 + # via trl +tzdata==2024.1 + # via pandas +urllib3==2.2.2 + # via requests +weblinx==0.3.0 + # via webnavix +xxhash==3.4.1 + # via datasets +yarl==1.9.4 + # via aiohttp diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..8829541 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,87 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.12 +target-version = "py312" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["ALL"] +ignore = [ + "ANN001", + "ANN201", + "C408", + "COM812", + "COM819", + "D100", + "D103", + "D104", + "D203", + "D213", + "D300", + "ERA001", + "E111", + "E114", + "E117", + "E501", + "FBT002", + "ISC001", + "ISC002", + "Q000", + "Q001", + "Q002", + "Q003", + "RET504", + "W191", +] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/src/README.md b/src/README.md new file mode 100644 index 0000000..8000bfb --- /dev/null +++ b/src/README.md @@ -0,0 +1,280 @@ +The following instructions assume you are running from this directory (you may need to `cd` to this directory). 
+
+### Download Data
+
+First, you need to download the `splits.json` file containing information about all the splits, as well as the `train.jsonl` candidates file selected by `McGill-NLP/MiniLM-L6-DMR`:
+
+```python
+from huggingface_hub import snapshot_download
+
+# splits.json
+snapshot_download(
+    repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", allow_patterns="splits.json", local_dir="./wl_data/"
+)
+
+# candidates files
+snapshot_download(
+    repo_id="McGill-NLP/WebLINX-full",
+    repo_type="dataset",
+    allow_patterns="candidates/*.jsonl",
+    local_dir="./wl_data/"
+)
+```
+
+Download the full dataset (warning: this will take a while):
+
+```python
+from huggingface_hub import snapshot_download
+
+snapshot_download(repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", local_dir="./wl_data/")
+```
+
+The default configs (`llama/conf/config.yaml`) assume that the `train.jsonl` is located at `./wl_data/candidates/train.jsonl`. If you want to change the path, you need to modify the `config.yaml` accordingly.
+
+### Set `WEBLINX_PROJECT_DIR`
+
+You need to set the `WEBLINX_PROJECT_DIR` environment variable to the root directory of the WebLINX project:
+
+```bash
+export WEBLINX_PROJECT_DIR=/path/to/the/modeling/directory/
+
+# For example, if you are in the modeling directory, you can run:
+export WEBLINX_PROJECT_DIR=$(pwd)
+```
+
+### Install Dependencies
+
+Dependencies are managed with Rye, so you can install them by running the following command:
+
+```bash
+rye sync
+```
+
+However, because `flash-attention` requires `torch` to be pre-installed, it has to be installed right after everything else:
+```bash
+pip install wheel
+# Regular install
+pip install "flash-attn>=2.3.0"
+# If you have limited RAM, you can try this:
+MAX_JOBS=4 pip install "flash-attn>=2.3.0" --no-build-isolation
+# If you have issues with nvcc, try this:
+FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install "flash-attn>=2.3.0" --no-build-isolation
+```
+
+### Optional: Symbolic linking to `WebLINX-full`
+
+If you downloaded the `WebLINX-full` data in a different location (e.g. a different disk) from your `weblinx/modeling` directory, you might consider using a symbolic link to avoid having to change the `config.yaml` files. You should do something like:
+
+```bash
+ln -s /location/of/your/full/data /location/of/project/weblinx/modeling/wl_data
+```
+
+For example, if your data is located at `/mnt/research/scratch/users/jdoe/WebLINX-full` but your cloned `weblinx` repository is at `~/dev/weblinx`, then you'd run:
+
+```bash
+ln -s /mnt/research/scratch/users/jdoe/WebLINX-full ~/dev/weblinx/modeling/wl_data
+```
+
+This corresponds to the `data.base_dir` specified in `config.yaml`, which is `"${project_dir}/wl_data/demonstrations/"`.
+
+### Dense Markup Ranking (DMR)
+
+#### Train DMR
+
+You can train the model by running the following command (it will automatically use the hydra config from `conf/`):
+
+```bash
+export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
+
+# Finetune MiniLM-L6-DMR (Default)
+python -m dmr.train
+
+# Finetune variant gte or bge
+python -m dmr.train +variant=gte
+python -m dmr.train +variant=bge
+```
+
+Results will be saved in `./results` and checkpoints in `./checkpoints`.
+
+#### Inference for DMR
+
+You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `iid` split, you can run the following command:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
+
+# On just one split
+python -m dmr.eval eval.split=valid
+
+# On multiple splits (e.g. test_iid, test_vis)
+python -m dmr.eval eval.split=test_iid,test_web,test_geo,test_cat,test_vis
+
+# Or for bge, gte
+python -m dmr.eval +variant=gte eval.split=test_iid,test_web,test_geo,test_cat,test_vis
+python -m dmr.eval +variant=bge eval.split=test_iid,test_web,test_geo,test_cat,test_vis
+```
+
+#### Moving generated DMR results to `wl_data/candidates`
+
+The `scores.jsonl` and `results.json` files will be saved at the path given by `cfg.eval.result_dir` in `modeling/dmr/conf/config.yml`, which defaults to `${project_dir}/results/${project_name}/${model.name}/${eval.split}`; this should resolve to `/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2/train` for the `train` split, `.../valid` for the `valid` split, and so on. However, since the next steps assume the candidates live at `wl_data/candidates/<split>.jsonl`, you need to move the files manually. For example, you could run:
+
+```bash
+# Change the following paths to match your setup
+orig_dir="/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2"
+
+# This is the directory where the candidates are stored
+new_dir="/path/to/wl_data/candidates"
+
+# You need to move the train split if you plan to use it for training the action model
+mv $orig_dir/train/scores.jsonl $new_dir/train.jsonl
+
+# You can move valid and test IID splits as well
+mv $orig_dir/valid/scores.jsonl $new_dir/valid.jsonl
+mv $orig_dir/test_iid/scores.jsonl $new_dir/test_iid.jsonl
+
+# You can move the other OOD test splits as well, after you have run the evaluation
+mv $orig_dir/test_web/scores.jsonl $new_dir/test_web.jsonl
+mv $orig_dir/test_geo/scores.jsonl $new_dir/test_geo.jsonl
+mv $orig_dir/test_cat/scores.jsonl $new_dir/test_cat.jsonl
+mv $orig_dir/test_vis/scores.jsonl $new_dir/test_vis.jsonl
+```
+
+Alternatively, you can update `config.yml` to save the results in the correct directory by overriding `candidates`:
+```yaml
+# ...
+candidates:
+  # ...
+  model: "sentence-transformers/all-MiniLM-L6-v2"
+  path: ${project_dir}/results/${project_name}/${model.name}/${eval.split}
+```
+
+### Action Model
+
+#### Train LLaMA
+
+You can train the model by running the following command (it will automatically use the hydra config from `conf/`):
+
+```bash
+export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
+
+# Finetune 1.3b variant
+python -m llama.train +variant="ft_1.3b"
+
+# Finetune 2.7b variant
+python -m llama.train +variant="ft_2.7b"
+
+# For 7b, you will need to use fsdp in accelerate to train on 4 GPUs with 48GB VRAM
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_7b.yaml -m llama.train +variant="ft_7b"
+
+# For LLaMA-3-8b-Instruct, you will need to use fsdp in accelerate to train on 4 GPUs with 48GB VRAM
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_7b.yaml -m llama.train +variant="ft_llama3_8b_instruct"
+
+# For 13b, you need 6 GPUs with 48GB VRAM
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
+accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_13b.yaml -m llama.train +variant="ft_13b"
+```
+
+Results will be saved in `./results` and checkpoints in `./checkpoints`.
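+
+If you just want to confirm that fine-tuning produced a loadable checkpoint before launching the full evaluation, a minimal smoke test along the following lines should work. The path is an assumption based on the default `model.save_dir` layout in `llama/conf/config.yaml` and the `ft_1.3b` variant; adjust it to match your run.
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Assumed layout: ${WEBLINX_PROJECT_DIR}/checkpoints/${project_name}/${model.name}
+ckpt_dir = "./checkpoints/llama_ft/princeton-nlp/Sheared-LLaMA-1.3B"
+
+tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
+model = AutoModelForCausalLM.from_pretrained(
+    ckpt_dir, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+# Any short prompt is enough to confirm the weights load and generate.
+inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
+output_ids = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+```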
+
+#### Run LLaMA on Evaluation Splits
+
+You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `iid` split, you can run the following command:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
+
+# On just one split
+python -m llama.eval +variant="ft_1.3b" eval.split=valid
+
+# On multiple splits (e.g. test_iid, test_vis)
+python -m llama.eval -m +variant="ft_2.7b" eval.split=test_iid,test_web,test_geo,test_cat,test_vis
+
+# Evaluating llama-3-8b-instruct on all splits
+python -m llama.eval -m +variant="ft_llama3_8b_instruct" eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis
+```
+
+### Evaluation
+
+To run the evaluation metrics, you can use the following command (from this directory):
+
+```bash
+python -m weblinx.eval -d results -b ./wl_data/demonstrations
+```
+
+In this case, `-b` is the base directory for the demonstrations, and `-d` is the directory containing the results (generated above by the `llama.eval` script). This will automatically run the evaluation metrics and save the results to `results/aggregated_scores.json`. If you are only interested in the overall score for a split (e.g. `valid`), you can look for the following entry in the aggregated scores file (as an example):
+
+```json
+// ...
+  {
+    "split": "valid",
+    "intent": "overall",
+    "metric": "overall",
+    "model_name": "princeton-nlp/Sheared-LLaMA-1.3B",
+    "project_name": "llama_ft",
+    "score": 0.21667765869744438,
+    "unconditional_score": 0.15307513104251605
+  },
+// ...
+```
+
+Behind the scenes, this uses the `weblinx.eval.auto_eval_and_save` function to run the evaluation metrics. If you want more control, you can also call that function directly; for an example, check out `weblinx/eval/__main__.py`.
+
+Note that it might be slow the first time you run it, because it reads a lot of demonstrations and loads millions of files. However, a demo-level cache is created automatically (see `./.cache/demonstrations`), so subsequent runs should be much faster.
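+
+If you want to pull just the overall score for each split out of `results/aggregated_scores.json` programmatically, a small helper like the one below should be enough. It assumes the file is a flat JSON list of entries shaped like the example above.
+
+```python
+import json
+from pathlib import Path
+
+# Assumes the aggregated file is a flat list of score entries.
+entries = json.loads(Path("results/aggregated_scores.json").read_text())
+
+overall = {
+    e["split"]: e["score"]
+    for e in entries
+    if e["intent"] == "overall" and e["metric"] == "overall"
+}
+print(overall)  # e.g. {"valid": 0.2166..., "test_iid": ...}
+```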
+ +### More models + +#### Flan-T5 + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m flan.train +python -m flan.eval -m eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + + +# Large +python -m modeling.flan.train model.name=google/flan-t5-large +python -m modeling.flan.eval -m model.name=google/flan-t5-large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# XL +python -m modeling.flan.train +variant=ft_xl +python -m modeling.flan.eval +variant=ft_xl eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` + +#### MindAct + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m modeling.flan.train +variant=ft_mindact model.size=base +python -m modeling.flan.eval -m +variant=ft_mindact model.size=base eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# Large +python -m modeling.flan.train +variant=ft_mindact model.size=large +python -m modeling.flan.eval -m +variant=ft_mindact model.size=large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# XL +python -m modeling.flan.train +variant=ft_mindact_xl +python -m modeling.flan.eval -m +variant=ft_mindact_xl eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` + + +#### Pix2Act + +First, you will need to download the tff file for the Arial font (aka `Arial.TFF`) and place it at `${project_dir}/modeling/fonts/Arial.TTF`. On Windows, you can find it at `C:\windows\fonts\`. On Linux, you can find alternative fonts at `/usr/share/fonts/truetype/`. + +```bash +export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use + +# Base +python -m modeling.pix2act.train +python -m modeling.pix2act.eval eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis + +# Large +python -m modeling.pix2act.train +variant=ft_large +python -m modeling.pix2act.eval +variant=ft_large eval.split=valid,test_iid,test_web,test_geo,test_cat,test_vis +``` \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..3ca9861 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,9 @@ +from huggingface_hub import snapshot_download + +snapshot_download( + repo_id="McGill-NLP/WebLINX-full", + repo_type="dataset", + local_dir="./wl_data", + ignore_patterns=["**/bboxes/", "**/pages/", "**/screenshots/", "**/video.mp4"], + resume_download=True, +) diff --git a/src/llama/__init__.py b/src/llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/llama/accelerate/fsdp_13b.yaml b/src/llama/accelerate/fsdp_13b.yaml new file mode 100644 index 0000000..afbbc1d --- /dev/null +++ b/src/llama/accelerate/fsdp_13b.yaml @@ -0,0 +1,28 @@ +# Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: 1 + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + # Set fsdp_use_orig_params=true if using peft: + # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main 
+mixed_precision: bf16 +num_machines: 1 +num_processes: 6 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/src/llama/accelerate/fsdp_7b.yaml b/src/llama/accelerate/fsdp_7b.yaml new file mode 100644 index 0000000..488cef9 --- /dev/null +++ b/src/llama/accelerate/fsdp_7b.yaml @@ -0,0 +1,28 @@ +# Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch_policy: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: 1 + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + # Set fsdp_use_orig_params=true if using peft: + # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/src/llama/conf/config.yaml b/src/llama/conf/config.yaml new file mode 100644 index 0000000..9029ab4 --- /dev/null +++ b/src/llama/conf/config.yaml @@ -0,0 +1,56 @@ +project_dir: ${oc.env:WEBLINX_PROJECT_DIR} +seed: 123 +project_name: llama + +data: + num_proc: 8 + split_path: ${project_dir}/wl_data/splits.json + base_dir: ${project_dir}/wl_data/demonstrations/ + +train: + split: train + num_epochs: 3 + learning_rate: 5e-5 + batch_size_per_device: 4 + gradient_accumulation_steps: 4 + dataloader_num_workers: 8 + gradient_checkpointing: True + use_accelerator_device_map: False + use_auto_device_map: True + warmup_ratio: 0 + scheduler: linear + optim: adamw_torch + +eval: + split: valid + batch_size_per_device: 2 + result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} + load_from_save_dir: False # If True, load from model.save_dir instead of model.name + +model: + name: meta-llama/Llama-2-7b-chat-hf + tokenizer: ${model.name} + template_tokenizer: ${model.tokenizer} + max_inp_len: null + max_out_len: 256 + use_rope: True + use_flash_attention_2: True + save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} + +candidates: + k: 10 + model: McGill-NLP/MiniLM-L6-dmr # unused but potentially useful + project_name: dmr # unused but potentially useful + split: ${eval.split} + train_path: ${project_dir}/wl_data/candidates/train.jsonl + path: ${project_dir}/wl_data/candidates/${candidates.split}.jsonl + +hydra: + run: + dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} + # Use the same for sweep's subdir + sweep: + dir: ${hydra.run.dir} + job: + chdir: False + verbose: INFO \ No newline at end of file diff --git a/src/llama/conf/variant/ft_1.3b.yaml b/src/llama/conf/variant/ft_1.3b.yaml new file mode 100644 index 0000000..8fbd1ec --- /dev/null +++ b/src/llama/conf/variant/ft_1.3b.yaml @@ -0,0 +1,11 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 1 gpu with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: princeton-nlp/Sheared-LLaMA-1.3B + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git 
a/src/llama/conf/variant/ft_13b.yaml b/src/llama/conf/variant/ft_13b.yaml new file mode 100644 index 0000000..84ccf17 --- /dev/null +++ b/src/llama/conf/variant/ft_13b.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 6 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Llama-2-13b-chat-hf + +train: + # 6 (# gpus) * 3 (accum steps) * 1 (bsize) = 18 (batch size) + batch_size_per_device: 1 + gradient_accumulation_steps: 3 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 4 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_2.7b.yaml b/src/llama/conf/variant/ft_2.7b.yaml new file mode 100644 index 0000000..45829fc --- /dev/null +++ b/src/llama/conf/variant/ft_2.7b.yaml @@ -0,0 +1,11 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 1 gpu with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: princeton-nlp/Sheared-LLaMA-2.7B + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_7b.yaml b/src/llama/conf/variant/ft_7b.yaml new file mode 100644 index 0000000..d29509c --- /dev/null +++ b/src/llama/conf/variant/ft_7b.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft + +model: + # This is meant to be run on 4 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Llama-2-7b-chat-hf + +train: + # 4 (# gpus) * 4 (accum steps) * 1 (bsize) = 16 (batch size) + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/conf/variant/ft_llama3_8b_instruct.yaml b/src/llama/conf/variant/ft_llama3_8b_instruct.yaml new file mode 100644 index 0000000..9703619 --- /dev/null +++ b/src/llama/conf/variant/ft_llama3_8b_instruct.yaml @@ -0,0 +1,18 @@ +# @package _global_ +project_name: llama_ft # TODO: Change this to your project name + +model: + # This is meant to be run on 4 gpus with 48GB+ memory (e.g., A6000) + use_flash_attention_2: True + name: meta-llama/Meta-Llama-3-8B-Instruct + +train: + # 4 (# gpus) * 4 (accum steps) * 1 (bsize) = 16 (batch size) + batch_size_per_device: 4 + gradient_accumulation_steps: 1 + use_accelerator_device_map: True + use_auto_device_map: False + +eval: + batch_size_per_device: 8 + load_from_save_dir: True \ No newline at end of file diff --git a/src/llama/eval.py b/src/llama/eval.py new file mode 100644 index 0000000..22c23a6 --- /dev/null +++ b/src/llama/eval.py @@ -0,0 +1,138 @@ +import json +import logging +from functools import partial +from pathlib import Path + +import hydra +import torch +import weblinx as wl +from omegaconf import OmegaConf +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + pipeline, +) +from transformers.pipelines.pt_utils import KeyDataset +from weblinx.processing import load_candidate_elements +from weblinx.processing.prompt import ( + build_input_records_from_selected_turns, + select_turns_and_candidates_for_prompts, +) +from weblinx.utils.hydra import save_path_to_hydra_logs + +from .processing import ( + build_formatter_for_multichoice, + build_prompt_records_for_llama_truncated, + insert_formatted_chat_into_records, +) + + +@hydra.main(version_base=None, config_path="conf", 
config_name="config") +def main(cfg): + logger = logging.getLogger(__name__) + + split_path = Path(cfg.data.split_path).expanduser() + result_dir = Path(cfg.eval.result_dir).expanduser() + model_save_dir = Path(cfg.model.save_dir).expanduser() + + max_out_len = cfg.model.max_out_len + split = cfg.eval.split + + result_dir.mkdir(parents=True, exist_ok=True) + + logger.info(OmegaConf.to_yaml(cfg)) + + candidates = load_candidate_elements(path=cfg.candidates.path) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + # Data loading + demo_names = wl.utils.load_demo_names_in_split(split_path, split=split) + demos = [wl.Demonstration(name, base_dir=cfg.data.base_dir) for name in demo_names] + + format_intent = build_formatter_for_multichoice() + build_prompt_records_fn = partial( + build_prompt_records_for_llama_truncated, + format_intent=format_intent, + tokenizer=tokenizer, + ) + + selected_turns = select_turns_and_candidates_for_prompts( + demos=demos, + candidates=candidates, + num_candidates=cfg.candidates.k, + ) + + input_records = build_input_records_from_selected_turns( + selected_turns=selected_turns, + format_intent=format_intent, + build_prompt_records_fn=build_prompt_records_fn, + format_prompt_records_fn=None, + ) + + template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) + insert_formatted_chat_into_records( + records=input_records, + tokenizer=template_tokenizer, + include_output_target=False, + ) + + model_kwargs = dict(device_map="auto", torch_dtype=torch.bfloat16) + + if cfg.model.use_rope: + model_kwargs["rope_scaling"] = {"type": "dynamic", "factor": 2.0} + + if cfg.model.use_flash_attention_2: + model_kwargs["use_flash_attention_2"] = True + + if cfg.eval.get("load_from_save_dir", False) is True: + model_load_name = str(model_save_dir) + else: + model_load_name = cfg.model.name + + model = AutoModelForCausalLM.from_pretrained(model_load_name, **model_kwargs) + + dset = KeyDataset(input_records, key="text") + pipe = pipeline( + "text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16 + ) + pipe_kwargs = dict( + max_new_tokens=max_out_len, + return_full_text=False, + batch_size=cfg.eval.batch_size_per_device, + pad_token_id=tokenizer.eos_token_id, + ) + + results = [] + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pbar = tqdm( + pipe(dset, **pipe_kwargs), desc="Generating outputs", total=len(dset) + ) + for i, out in enumerate(pbar): + rec = input_records[i] + generated_text = out[0]["generated_text"] + result = { + "demo_name": rec["demo_name"], + "turn_index": rec["turn_index"], + "prompt": rec["prompt"], + "text": rec["text"], + "output_predicted": generated_text, + "output_target": rec["output_target"], + "output_target_dict": rec["output_target_dict"], + } + + results.append(result) + + # Save results + with Path.open(result_dir / "results.json", "w") as f: + json.dump(results, f, indent=2) + + # Save the path to hydra_path into the model directory + save_path_to_hydra_logs(save_dir=result_dir) + + +if __name__ == "__main__": + main() diff --git a/src/llama/processing.py b/src/llama/processing.py new file mode 100644 index 0000000..020b6e4 --- /dev/null +++ b/src/llama/processing.py @@ -0,0 +1,307 @@ +from collections.abc import Callable +from copy import deepcopy +from functools import partial + +import lxml.html +import weblinx.utils.format as wlf +from weblinx.processing.dom import clean_and_prune_tree +from 
weblinx.processing.prompt import ( + find_turns_with_instructor_chat, + format_candidates, + format_utterances, + format_utterances_truncated, + get_speaker, + multi_attempt_format_prev_turns_truncated, +) +from weblinx.processing.truncation import ( + multi_attempt_truncate_cands_turn, + multi_attempt_truncate_dom_tree, +) + + +def build_formatter_for_multichoice(): + format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) + format_text_input = partial( + wlf.format_text_input, + formatters=( + partial(wlf.format_arg_item, name="text", max_length=200), + wlf.format_uid, + ), + ) + format_change = partial( + wlf.format_change, + formatters=( + partial(wlf.format_arg_item, name="value", max_length=200), + wlf.format_uid, + ), + ) + format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) + format_load = partial( + wlf.format_load, + include_transition=False, + include_timestamp=False, + max_length=200, + ) + format_scroll = partial(wlf.format_scroll, include_timestamp=False) + + format_say = partial(wlf.format_say, include_timestamp=False) + + format_intent_auto = partial( + wlf.format_intent_automatically, + format_change=format_change, + format_click=format_click, + format_load=format_load, + format_say=format_say, + format_scroll=format_scroll, + format_submit=format_submit, + format_text_input=format_text_input, + ) + + return format_intent_auto + + +def get_system_prompt_template_for_llama_mc_concise(): + sys_prompt_template = ( + "You are an AI assistant with a deep understanding of HTML " + "and you must predict actions based on a user request, which will be executed. " + "Use one of the following, replacing [] with an appropriate value: " + "change(value=[str], uid=[str]) ; " + "click(uid=[str]) ; " + "load(url=[str]) ; " + 'say(speaker="navigator", utterance=[str]) ; ' + "scroll(x=[int], y=[int]) ; " + "submit(uid=[str]) ;" + "text_input(text=[str], uid=[str]) ;\n" + "The user's first and last {num_utterances} utterances are: " + "{utterance_context} ;\n" + "Viewport size: {height}h x {width}w ;\n" + "Only the last {num_prev_turns} turns are provided." + ) + + return sys_prompt_template + + +def get_candidate_prompt_template_for_llama(): + return "Here are the top candidates for this turn: {candidate_str}\n" + + +def get_final_user_message(): + return ( + "Please select the best action using the correct format, " + "do not provide any other information or explanation." 
+ ) + + +def merge_prev_turns(prev_turns_text_list, final_user_message): + prev_turns_merged: list[dict[str, str]] = [] + + # Merge turns from the same role + for i, turn_text in enumerate(prev_turns_text_list): + role = get_speaker( + turn_text, + instructor_name="user", + navigator_name="assistant", + default_name="unknown", + ) + + if i > 0 and prev_turns_merged[-1]["role"] == role: + prev_turns_merged[-1]["content"] += " " + turn_text + else: + prev_turns_merged.append({"role": role, "content": turn_text}) + + if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user": + prev_turns_merged[-1]["content"] += " " + final_user_message + else: + prev_turns_merged.append({"role": "user", "content": final_user_message}) + + return prev_turns_merged + + +def build_prompt_records_for_llama_truncated( # noqa: PLR0913 + replay, + turn, + format_intent, + tokenizer, + cands_turn=None, + num_utterances=5, + num_prev_turns=5, + system_prompt_template=None, + candidate_prompt_template=None, + final_user_message=None, + include_html=True, + format_candidates_fn=partial( # noqa: B008 + format_candidates, max_char_len=None, use_uid_as_rank=True + ), + merge_prev_turns_fn=merge_prev_turns, + format_output_dict_fn: Callable = partial( # noqa: B008 + wlf.format_output_dictionary, function_key="intent" + ), + max_html_tokens=700, + max_utterance_tokens=40 * 5, + max_prev_turns_tokens=50 * 5, + max_candidates_tokens=65 * 10, + add_unused_len_to_cands=True, + allow_iterative_reduction=False, + use_tokenizer_template=False, + template_tokenizer=None, + parser=None, +) -> list[dict[str, str]]: + """Parameters + + ---------- + ... + allow_iterative_reduction : bool + This arg is only relevant when truncate_at_center is used behind the scene (e.g. for + multi_attempt_format_prev_turns_truncated or multi_attempt_truncate_dom_tree). If True, + then we will allow the iterative reduction to continue until the max_tokens is reached. + This is useful when the tokenizer output does not necessarily decrease when we remove + tokens from the input. For example, if we remove a token that is part of a word, but + the updated text is retokenized to the same number of tokens, then we will continue + to remove tokens until we reach the max_tokens limit. 
+ """ + if system_prompt_template is None: + system_prompt_template = get_system_prompt_template_for_llama_mc_concise() + + if candidate_prompt_template is None: + candidate_prompt_template = get_candidate_prompt_template_for_llama() + + if final_user_message is None: + final_user_message = get_final_user_message() + + instructor_chat_turns = find_turns_with_instructor_chat( + replay, turn, num_prev_turns=num_prev_turns + ) + utterance_context = format_utterances_truncated( + instructor_chat_turns, + tokenizer=tokenizer, + max_tokens=max_utterance_tokens, + num_utterances=num_utterances, + format_utterances_fn=format_utterances, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_text_list = multi_attempt_format_prev_turns_truncated( + replay=replay, + turn=turn, + format_intent=partial(format_intent, return_as=dict), + tokenizer=tokenizer, + num_prev_turns=num_prev_turns, + turn_sep=None, # output list + max_tokens=max_prev_turns_tokens, + max_attempts=5, + format_output_dict_fn=format_output_dict_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + + prev_turns_merged = merge_prev_turns_fn( + prev_turns_text_list=prev_turns_text_list, final_user_message=final_user_message + ) + + sys_prompt = system_prompt_template.format( + num_utterances=num_utterances - 1, # 1 less since we add the first utterance + utterance_context=utterance_context, + height=turn.viewport_height, + width=turn.viewport_width, + num_prev_turns=num_prev_turns, + ) + + if include_html and turn.html not in ["", None] and cands_turn is not None: + dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser) + dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn) + trunc = multi_attempt_truncate_dom_tree( + dom_tree=dom_tree_pruned, + tokenizer=tokenizer, + max_tokens=max_html_tokens, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + html = trunc["tree_repr"] + sys_prompt = html + sys_prompt + else: + html = "" + + if cands_turn is not None: + if add_unused_len_to_cands: + # Add the unused length to the candidates + num_html_tokens = len(tokenizer.tokenize(html)) + num_utter_tokens = len(tokenizer.tokenize(utterance_context)) + if use_tokenizer_template: + if template_tokenizer is None: + msg = "template_tokenizer must be provided when use_tokenizer_template is True." 
+ raise ValueError(msg) + prev_turns_merged_copy = deepcopy(prev_turns_merged) + if prev_turns_merged[0]["role"] == "assistant": + # insert a dummy user turn + prev_turns_merged_copy.insert(0, {"role": "user", "content": ""}) + num_prev_turns_tokens = len( + template_tokenizer.apply_chat_template( + [{"role": "system", "content": ""}, *prev_turns_merged_copy], + tokenize=True, + ) + ) + else: + num_prev_turns_tokens = len( + tokenizer.tokenize(" ".join(prev_turns_text_list)) + ) + remain_html_tokens = max_html_tokens - num_html_tokens + remain_utter_tokens = max_utterance_tokens - num_utter_tokens + remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens + remain_tokens = ( + remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens + ) + # Add the unused length to the max_candidates_tokens + max_candidates_tokens += remain_tokens + + cands_turn_trunc = multi_attempt_truncate_cands_turn( + cands_turn=cands_turn, + tokenizer=tokenizer, + max_tokens=max_candidates_tokens, + format_candidates_fn=format_candidates_fn, + warn_after_attempts=False, + allow_iterative_reduction=allow_iterative_reduction, + ) + cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None) + cand_prompt = candidate_prompt_template.format(candidate_str=cand_str) + sys_prompt += "\n" + cand_prompt + + return [{"role": "system", "content": sys_prompt}, *prev_turns_merged] + + +def __insert_empty_user_content_at_first(prompt: list) -> None: + """Given a list of dictionary representing the input prompt, insert an empty user content at the first position. + + after system content, only if it is not already a user content. This is done in place. + """ + if prompt[0]["role"] != "system": + msg = f"First prompt must be a system prompt. Got {prompt[0]['role']} instead." + raise ValueError(msg) + + if prompt[1]["role"] != "user": + prompt.insert(1, {"role": "user", "content": ""}) + + +def insert_formatted_chat_into_records( + records, + tokenizer, + include_output_target=True, + origin_key="prompt", + text_key="text", +): + """Given a list of records, insert the formatted chat into the records. This is done in place. + + Note that we need a tokenizer's `apply_chat_template` method to be available. 
+ """ + for i, record in enumerate(records): + __insert_empty_user_content_at_first(record[origin_key]) + + if include_output_target: + target = [{"role": "assistant", "content": record["output_target"]}] + combined = record[origin_key] + target + else: + combined = record[origin_key] + + text = tokenizer.apply_chat_template( + combined, tokenize=False, add_generation_prompt=False + ) + records[i][text_key] = text diff --git a/src/llama/train.py b/src/llama/train.py new file mode 100644 index 0000000..3155bca --- /dev/null +++ b/src/llama/train.py @@ -0,0 +1,135 @@ +import json +import logging +import typing +from functools import partial +from pathlib import Path + +import datasets +import hydra +import torch +import weblinx as wl +from accelerate import Accelerator +from omegaconf import OmegaConf +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + TrainingArguments, +) +from trl import SFTTrainer +from weblinx.processing import load_candidate_elements +from weblinx.processing.prompt import ( + build_input_records_from_selected_turns, + select_turns_and_candidates_for_prompts, +) +from weblinx.utils import set_seed +from weblinx.utils.hydra import save_path_to_hydra_logs + +from .processing import ( + build_formatter_for_multichoice, + build_prompt_records_for_llama_truncated, + insert_formatted_chat_into_records, +) + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg): + set_seed(cfg.seed) + split_path = Path(cfg.data.split_path).expanduser() + model_save_dir = Path(cfg.model.save_dir).expanduser() + model_save_dir.mkdir(exist_ok=True, parents=True) + logging.info(OmegaConf.to_yaml(cfg)) + + demo_names = wl.utils.load_demo_names_in_split(split_path, split=cfg.train.split) + demos = [ + wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) + for demo_name in demo_names + ] + candidates = load_candidate_elements(path=cfg.candidates.train_path) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="right") + tokenizer.pad_token = tokenizer.eos_token + + model_kwargs: dict[str, typing.Any] = dict(torch_dtype=torch.bfloat16) + + if cfg.train.use_accelerator_device_map: + accelerator = Accelerator() + model_kwargs["device_map"] = {"": accelerator.process_index} + + elif cfg.train.use_auto_device_map: + model_kwargs["device_map"] = "auto" + + if cfg.model.use_flash_attention_2: + model_kwargs["use_flash_attention_2"] = True + + model = AutoModelForCausalLM.from_pretrained(cfg.model.name, **model_kwargs) + + format_intent = build_formatter_for_multichoice() + input_records_fname = "input_records_trunc.json" + build_prompt_records_fn = partial( + build_prompt_records_for_llama_truncated, + format_intent=format_intent, + tokenizer=tokenizer, + ) + + selected_turns = select_turns_and_candidates_for_prompts( + demos=demos, + candidates=candidates, + num_candidates=cfg.candidates.k, + ) + + input_records = build_input_records_from_selected_turns( + selected_turns=selected_turns, + format_intent=format_intent, + build_prompt_records_fn=build_prompt_records_fn, + format_prompt_records_fn=None, + ) + + template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) + insert_formatted_chat_into_records( + input_records, template_tokenizer, include_output_target=True + ) + + with Path.open(model_save_dir.joinpath(input_records_fname), "w") as f: + json.dump(input_records, f, indent=2) + + input_records_texts = [{"text": record["text"]} for record in input_records] + + training_args = TrainingArguments( 
+ output_dir=str(model_save_dir), + optim=cfg.train.optim, + learning_rate=cfg.train.learning_rate, + num_train_epochs=cfg.train.num_epochs, + per_device_train_batch_size=cfg.train.batch_size_per_device, + gradient_accumulation_steps=cfg.train.gradient_accumulation_steps, + gradient_checkpointing=cfg.train.gradient_checkpointing, + warmup_ratio=cfg.train.warmup_ratio, + lr_scheduler_type=cfg.train.scheduler, + save_strategy="no", + evaluation_strategy="no", + logging_strategy="epoch", + logging_first_step=True, + prediction_loss_only=True, + bf16=True, + bf16_full_eval=True, + ) + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=datasets.Dataset.from_list(input_records_texts), + max_seq_length=model.config.max_position_embeddings, + dataset_text_field="text", + ) + + trainer.train() + + # Save model, tokenizer, trainer state, and path to hydra logs + trainer.save_model(model_save_dir) + tokenizer.save_pretrained(model_save_dir) + trainer.state.save_to_json(model_save_dir / "trainer_state.json") + save_path_to_hydra_logs(save_dir=model_save_dir) + + +if __name__ == "__main__": + main() diff --git a/webnavix.code-workspace b/webnavix.code-workspace new file mode 100644 index 0000000..a26ff12 --- /dev/null +++ b/webnavix.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "name": "webnavix", + "path": "." + } + ] +}