From 0580b06af968cf8b7b4c74f26cd6b69f56094643 Mon Sep 17 00:00:00 2001 From: Ryota Murai <35632215+rmuraix@users.noreply.github.com> Date: Thu, 7 Nov 2024 19:12:42 +0900 Subject: [PATCH] feat: add notebooks (#2) * feat: add playground.ipynb * feat: add fine-tuning.ipynb --- src/fine-tuning.ipynb | 439 ++++++++++++++++++++++++++++++++++++++++++ src/playground.ipynb | 157 +++++++++++++++ 2 files changed, 596 insertions(+) create mode 100644 src/fine-tuning.ipynb create mode 100644 src/playground.ipynb diff --git a/src/fine-tuning.ipynb b/src/fine-tuning.ipynb new file mode 100644 index 0000000..2c2158d --- /dev/null +++ b/src/fine-tuning.ipynb @@ -0,0 +1,439 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ファインチューニング - 入門 Hugging Face🤗\n", + "\n", + "[GitHub](https://github.com/tpu-dsg/hf-hands-on)\n", + "\n", + "このノートブックでは、[Hugging Face🤗](https://huggingface.co/)のエコシステムを活用して、ファインチューニングによる画像分類を行います。\n", + "\n", + "NOTICE: [Hugging Faceのガイド](https://huggingface.co/docs/transformers/ja/tasks/image_classification)を参考に作成されました。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import evaluate\n", + "from transformers import (\n", + " AutoImageProcessor,\n", + " AutoModelForImageClassification,\n", + " TrainingArguments,\n", + " Trainer,\n", + " DefaultDataCollator,\n", + ")\n", + "from datasets import load_dataset\n", + "from huggingface_hub import notebook_login\n", + "import albumentations as A\n", + "from albumentations.pytorch import ToTensorV2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 使用するモデル\n", + "CHECKPOINT: str = \"google/vit-base-patch16-224-in21k\"\n", + "# 訓練後のモデルの名前\n", + "TUNED_MODEL_NAME: str = \"my_awesome_food_model\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image_processor = AutoImageProcessor.from_pretrained(CHECKPOINT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## データセットの準備\n", + "\n", + "Food-101データセットのサブセットをロードし、データセットの`train`をtrainセットとtestセットに分割します。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "food = load_dataset(\"food101\", split=\"train[:5000]\")\n", + "food = food.train_test_split(test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "一つデータを見てみましょう。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "food[\"train\"][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "データセット内の各例には 2 つのフィールドがあります。\n", + "\n", + "- `image`: 食品の PIL 画像\n", + "- `label`: 食品のラベルクラス" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "モデルがラベル ID からラベル名を取得しやすくするために、ラベル名をマップする辞書を作成します。 " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels = food[\"train\"].features[\"label\"].names\n", + "label2id, id2label = dict(), dict()\n", + "\n", + "for i, label in enumerate(labels):\n", + " label2id[label] = str(i)\n", + " id2label[str(i)] = label" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "これで、ラベル ID をラベル名に変換できるようになりました。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "id2label[str(79)]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[albumentations](https://albumentations.ai/)を使用したデータ拡張を定義します。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_transforms = A.Compose(\n", + " [\n", + " A.Resize(image_processor.size[\"height\"], image_processor.size[\"width\"]),\n", + " A.RandomCrop(image_processor.size[\"height\"], image_processor.size[\"width\"]),\n", + " A.HorizontalFlip(p=0.5),\n", + " A.RandomBrightnessContrast(p=0.2),\n", + " A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),\n", + " ToTensorV2(),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "次に、変換を適用し画像の`pixel_values`(モデルへの入力) を返す前処理関数を作成します。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def transforms(examples):\n", + " examples[\"pixel_values\"] = [\n", + " _transforms(image=np.array(img.convert(\"RGB\")))[\"image\"]\n", + " for img in examples[\"image\"]\n", + " ]\n", + " del examples[\"image\"]\n", + " return examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "food = food.with_transform(transforms)\n", + "data_collator = DefaultDataCollator()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 学習の設定" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "メトリクスの計算方法を定義\n", + "\n", + "今回は正解率を最大化するようにします。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "accuracy = evaluate.load(\"accuracy\")\n", + "\n", + "\n", + "def compute_metrics(eval_pred):\n", + " predictions, labels = eval_pred\n", + " predictions = np.argmax(predictions, axis=1)\n", + " return accuracy.compute(predictions=predictions, references=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "モデルをロード" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForImageClassification.from_pretrained(\n", + " CHECKPOINT,\n", + " num_labels=len(labels),\n", + " id2label=id2label,\n", + " label2id=label2id,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "トレーニング引数をTrainerに渡します。\n", + "\n", + "Tips: `push_to_hub=True`を設定すると、このモデルをHubにプッシュできます。(Huffing Faceへのログインが必要です):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=TUNED_MODEL_NAME,\n", + " remove_unused_columns=False,\n", + " eval_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " learning_rate=5e-5,\n", + " per_device_train_batch_size=16,\n", + " gradient_accumulation_steps=4,\n", + " per_device_eval_batch_size=16,\n", + " num_train_epochs=10,\n", + " warmup_ratio=0.1,\n", + " logging_steps=10,\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"accuracy\",\n", + " push_to_hub=False,\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " train_dataset=food[\"train\"],\n", + " eval_dataset=food[\"test\"],\n", + " processing_class=image_processor,\n", + " compute_metrics=compute_metrics,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 学習\n", + "\n", + "`TUNED_MODEL_NAME`に指定した名前のディレクトリにチェックポイントが保存されます。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "モデルの保存" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(\"./\" + TUNED_MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`push_to_hub=True`を設定し、ログイン済みであれば、以下を実行することでモデルを公開できます。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# trainer.push_to_hub()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 推論" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "データの読み込み" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = load_dataset(\"food101\", split=\"validation[:10]\")\n", + "image = ds[\"image\"][0]\n", + "image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "先ほど学習したモデルを使用して推論を実行します。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image_processor = AutoImageProcessor.from_pretrained(\"./\" + TUNED_MODEL_NAME)\n", + "inputs = image_processor(image, return_tensors=\"pt\")\n", + "\n", + "model = AutoModelForImageClassification.from_pretrained(\"./\" + TUNED_MODEL_NAME)\n", + "\n", + "with torch.no_grad():\n", + " logits = model(**inputs).logits" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "結果の表示" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_label = logits.argmax(-1).item()\n", + "\n", + "model.config.id2label[predicted_label]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/playground.ipynb b/src/playground.ipynb new file mode 100644 index 0000000..a6e27bd --- /dev/null +++ b/src/playground.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Playground - 入門 Hugging Face🤗\n", + "\n", + "[GitHub](https://github.com/tpu-dsg/hf-hands-on)\n", + "\n", + "このノートブックでは、[Hugging Face🤗](https://huggingface.co/)のエコシステムを活用して、学習済みモデルを用いた画像分類を行います。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from urllib.parse import urlparse\n", + "import os\n", + "import sys\n", + "\n", + "from transformers import AutoImageProcessor, AutoModelForImageClassification\n", + "from PIL import Image\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "以下で推論に使用する画像分類モデルと、推論に使用する画像を指定しましょう。\n", + "\n", + "- 画像分類モデルはこちらから探してみましょう:https://huggingface.co/models?pipeline_tag=image-classification&sort=trending\n", + "- `apple/mobilevit-small`のページはこちら:https://huggingface.co/apple/mobilevit-small" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 推論に使用するモデル\n", + "MODEL_NAME: str = \"apple/mobilevit-small\"\n", + "# 推論に使用する画像のURLまたはパス\n", + "IMAGE: str = \"https://images.unsplash.com/photo-1523974837767-33c0fbdd9f6a?q=80&w=1974&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parsed_url = urlparse(IMAGE)\n", + "if parsed_url.scheme in (\"http\", \"https\"):\n", + " # URLが指定された場合\n", + " img = Image.open(requests.get(IMAGE, stream=True).raw)\n", + "elif os.path.exists(IMAGE):\n", + " # ファイルパスが指定された場合\n", + " img = Image.open(IMAGE)\n", + "else:\n", + " sys.exit(\"有効な画像のURLまたはパスを指定してください\")\n", + "\n", + "img.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## プロセッサーとモデルの読み込み\n", + "\n", + "モデルと[Image Processor](https://huggingface.co/docs/transformers/main_classes/image_processor)を読み込みます。Image Processorはモデルが期待する形式に画像を変換する役割を担います。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 画像の前処理用のプロセッサ\n", + "processor = AutoImageProcessor.from_pretrained(MODEL_NAME)\n", + "# モデルの読み込み\n", + "model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 推論" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 画像を変換\n", + "input = processor(images=img, return_tensors=\"pt\")\n", + "\n", + "# モデルの推論\n", + "with torch.no_grad():\n", + " output = model(**input)\n", + " logits = output.logits\n", + "\n", + "# 推論結果の表示\n", + "predicted_class_idx = logits.argmax(-1).item()\n", + "print(f\"Predicted class: {model.config.id2label[predicted_class_idx]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "コードを書き換えて、他にも色々と試してみましょう!\n", + "\n", + "- 他の画像分類モデルを使用してみる\n", + "- **分類タスク以外**のモデルを使用してみる\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}