From c87cb4e7618ec23b32543cdc4c851a57764361cd Mon Sep 17 00:00:00 2001
From: rdorrepa
Date: Mon, 22 Apr 2024 11:36:58 +0100
Subject: [PATCH] adding pytorch quantization sparsity aware training

---
 .ci/spellcheck/.pyspelling.wordlist.txt       |   5 +-
 notebooks/README.md                           |   1 +
 .../README.md                                 |  24 +
 .../config.json                               |  36 +
 ...quantization-sparsity-aware-training.ipynb | 769 ++++++++++++++++++
 5 files changed, 834 insertions(+), 1 deletion(-)
 create mode 100644 notebooks/pytorch-quantization-sparsity-aware-training/README.md
 create mode 100644 notebooks/pytorch-quantization-sparsity-aware-training/config.json
 create mode 100644 notebooks/pytorch-quantization-sparsity-aware-training/pytorch-quantization-sparsity-aware-training.ipynb

diff --git a/.ci/spellcheck/.pyspelling.wordlist.txt b/.ci/spellcheck/.pyspelling.wordlist.txt
index d3ec4ee20e1..c9623a43454 100644
--- a/.ci/spellcheck/.pyspelling.wordlist.txt
+++ b/.ci/spellcheck/.pyspelling.wordlist.txt
@@ -323,6 +323,7 @@ JIT
 Joao
 JS
 JSON
+json
 JT
 JuggernautXL
 Jupyter
@@ -691,6 +692,8 @@ softvc
 SoftVC
 SOTA
 Sovits
+sparsity
+Sparsity
 sparsified
 sparsify
 spectrogram
@@ -858,4 +861,4 @@ Zongyuan
 ZeroScope
 zeroscope
 zh
-xformers
\ No newline at end of file
+xformers
diff --git a/notebooks/README.md b/notebooks/README.md
index 2bac66448b4..b3660142a28 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -256,6 +256,7 @@
 - [From Training to Deployment with TensorFlow and OpenVINO™](./tensorflow-training-openvino/tensorflow-training-openvino.ipynb)
 - [Quantization Aware Training with NNCF, using TensorFlow Framework](./tensorflow-quantization-aware-training/tensorflow-quantization-aware-training.ipynb)
 - [Quantization Aware Training with NNCF, using PyTorch framework](./pytorch-quantization-aware-training/pytorch-quantization-aware-training.ipynb)
+- [Quantization Sparsity Aware Training with NNCF, using PyTorch framework](./pytorch-quantization-sparsity-aware-training/pytorch-quantization-sparsity-aware-training.ipynb)
 
 ## Optimize
 
diff --git a/notebooks/pytorch-quantization-sparsity-aware-training/README.md b/notebooks/pytorch-quantization-sparsity-aware-training/README.md
new file mode 100644
index 00000000000..2acfefdf6d7
--- /dev/null
+++ b/notebooks/pytorch-quantization-sparsity-aware-training/README.md
@@ -0,0 +1,24 @@
+# Optimizing PyTorch models with Neural Network Compression Framework of OpenVINO™ by 8-bit quantization and sparsity
+
+This tutorial demonstrates how to use [NNCF](https://github.com/openvinotoolkit/nncf) 8-bit quantization combined with sparsity to optimize a
+[PyTorch](https://pytorch.org/) model for inference with [OpenVINO Toolkit](https://docs.openvino.ai/).
+For more advanced usage, refer to these [examples](https://github.com/openvinotoolkit/nncf/tree/develop/examples).
+
+This notebook is based on the 'ImageNet training in PyTorch' [example](https://github.com/pytorch/examples/blob/master/imagenet/main.py).
+It uses a [ResNet-18](https://arxiv.org/abs/1512.03385) model with the
+Tiny ImageNet dataset.
+
+## Notebook Contents
+
+This tutorial consists of the following steps:
+* Transforming the original dense `FP32` model to a sparse `INT8` model
+* Using fine-tuning to restore the accuracy
+* Exporting the optimized and original models to OpenVINO IR
+* Measuring and comparing the performance of the models
+
+## Installation Instructions
+
+This is a self-contained example that relies solely on its own code and the accompanying `config.json` file.
+We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. +For details, please refer to [Installation Guide](../../README.md). + diff --git a/notebooks/pytorch-quantization-sparsity-aware-training/config.json b/notebooks/pytorch-quantization-sparsity-aware-training/config.json new file mode 100644 index 00000000000..8313f9aad83 --- /dev/null +++ b/notebooks/pytorch-quantization-sparsity-aware-training/config.json @@ -0,0 +1,36 @@ +{ + "target_device": "CPU", + "input_info": { + "sample_size": [ + 1, + 3, + 224, + 224 + ] + }, + "compression": [ + { + "algorithm": "magnitude_sparsity", + "sparsity_init": 0, + "params": { + "schedule": "multistep", + "sparsity_freeze_epoch": 5, + "multistep_steps": [2], + "multistep_sparsity_levels": [0.3, 0.5] + }, + "ignored_scopes": [ + "ResNet/NNCFConv2d[conv1]/conv2d_0" + ] + }, + { + "algorithm": "quantization", + "initializer": { + "range": { + "num_init_samples": 10000 + } + }, + "ignored_scopes": [ + ] + } + ] +} \ No newline at end of file diff --git a/notebooks/pytorch-quantization-sparsity-aware-training/pytorch-quantization-sparsity-aware-training.ipynb b/notebooks/pytorch-quantization-sparsity-aware-training/pytorch-quantization-sparsity-aware-training.ipynb new file mode 100644 index 00000000000..d1973768dcf --- /dev/null +++ b/notebooks/pytorch-quantization-sparsity-aware-training/pytorch-quantization-sparsity-aware-training.ipynb @@ -0,0 +1,769 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "git68adWeq4l" + }, + "source": [ + "# Quantization-Sparsity Aware Training with NNCF, using PyTorch framework\n", + "\n", + "This notebook is based on [ImageNet training in PyTorch](https://github.com/pytorch/examples/blob/master/imagenet/main.py).\n", + "\n", + "The goal of this notebook is to demonstrate how to use the Neural Network Compression Framework [NNCF](https://github.com/openvinotoolkit/nncf) 8-bit quantization to optimize a PyTorch model for inference with OpenVINO Toolkit. The optimization process contains the following steps:\n", + "\n", + "* Transforming the original dense `FP32` model to sparse `INT8`\n", + "* Using fine-tuning to improve the accuracy.\n", + "* Exporting optimized and original models to OpenVINO IR\n", + "* Measuring and comparing the performance of models.\n", + "\n", + "For more advanced usage, refer to these [examples](https://github.com/openvinotoolkit/nncf/tree/develop/examples).\n", + "\n", + "This tutorial uses the ResNet-50 model with the ImageNet dataset. The dataset must be downloaded separately. 
To see ResNet models, visit [PyTorch hub](https://pytorch.org/hub/pytorch_vision_resnet/).\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "#### Table of contents:\n", + "\n", + "- [Imports and Settings](#Imports-and-Settings)\n", + "- [Pre-train Floating-Point Model](#Pre-train-Floating-Point-Model)\n", + " - [Train Function](#Train-Function)\n", + " - [Validate Function](#Validate-Function)\n", + " - [Helpers](#Helpers)\n", + " - [Get a Pre-trained FP32 Model](#Get-a-Pre-trained-FP32-Model)\n", + "- [Create and Initialize Quantization](#Create-and-Initialize-Quantization)\n", + "- [Fine-tune the Compressed Model](#Fine-tune-the-Compressed-Model)\n", + "- [Export INT8 Sparse Model to OpenVINO IR](#Export-INT8-Model-to-OpenVINO-IR)\n", + "- [Benchmark Model Performance by Computing Inference Time](#Benchmark-Model-Performance-by-Computing-Inference-Time)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu \"openvino>=2024.0.0\" \"torch\" \"torchvision\" \"tqdm\"\n", + "%pip install -q \"nncf>=2.9.0\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "6M1xndNu-z_2" + }, + "source": [ + "## Imports and Settings\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "On Windows, add the required C++ directories to the system PATH.\n", + "\n", + "Import NNCF and all auxiliary packages from your Python code.\n", + "Set a name for the model, and the image width and height that will be used for the network. Also define paths where PyTorch and OpenVINO IR versions of the models will be stored. \n", + "\n", + "> **NOTE**: All NNCF logging messages below ERROR level (INFO and WARNING) are disabled to simplify the tutorial. For production use, it is recommended to enable logging by removing ```set_log_level(logging.ERROR)```." 
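For reference, this is roughly what the silencing call looks like — a minimal sketch, assuming the installed NNCF version exposes `set_log_level` under `nncf.common.logging.logger`:

```python
# Minimal sketch: keep only ERROR-level NNCF messages during the tutorial.
# Assumption: this NNCF version provides `set_log_level` at this import path.
import logging

from nncf.common.logging.logger import set_log_level

set_log_level(logging.ERROR)  # remove this call to restore the default, more verbose output
```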
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BtaM_i2mEB0z", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import time\n", + "import warnings # To disable warnings on export model\n", + "from pathlib import Path\n", + "\n", + "import torch\n", + "\n", + "import torch.nn as nn\n", + "import torch.nn.parallel\n", + "import torch.optim\n", + "import torch.utils.data\n", + "import torch.utils.data.distributed\n", + "import torchvision.datasets as datasets\n", + "import torchvision.transforms as transforms\n", + "import torchvision.models as models\n", + "\n", + "import openvino as ov\n", + "from torch.jit import TracerWarning\n", + "\n", + "torch.manual_seed(0)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using {device} device\")\n", + "\n", + "MODEL_DIR = Path(\"model\")\n", + "OUTPUT_DIR = Path(\"output\")\n", + "# DATA_DIR = Path(\"...\") # Insert path to folder containing imagenet folder\n", + "# DATASET_DIR = DATA_DIR / \"imagenet\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch `notebook_utils` module\n", + "import zipfile\n", + "import requests\n", + "\n", + "r = requests.get(\n", + " url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\",\n", + ")\n", + "open(\"notebook_utils.py\", \"w\").write(r.text)\n", + "from notebook_utils import download_file\n", + "\n", + "DATA_DIR = Path(\"data\")\n", + "\n", + "\n", + "def download_tiny_imagenet_200(\n", + " data_dir: Path,\n", + " url=\"http://cs231n.stanford.edu/tiny-imagenet-200.zip\",\n", + " tarname=\"tiny-imagenet-200.zip\",\n", + "):\n", + " archive_path = data_dir / tarname\n", + " download_file(url, directory=data_dir, filename=tarname)\n", + " zip_ref = zipfile.ZipFile(archive_path, \"r\")\n", + " zip_ref.extractall(path=data_dir)\n", + " zip_ref.close()\n", + "\n", + "\n", + "def prepare_tiny_imagenet_200(dataset_dir: Path):\n", + " # Format validation set the same way as train set is formatted.\n", + " val_data_dir = dataset_dir / \"val\"\n", + " val_annotations_file = val_data_dir / \"val_annotations.txt\"\n", + " with open(val_annotations_file, \"r\") as f:\n", + " val_annotation_data = map(lambda line: line.split(\"\\t\")[:2], f.readlines())\n", + " val_images_dir = val_data_dir / \"images\"\n", + " for image_filename, image_label in val_annotation_data:\n", + " from_image_filepath = val_images_dir / image_filename\n", + " to_image_dir = val_data_dir / image_label\n", + " if not to_image_dir.exists():\n", + " to_image_dir.mkdir()\n", + " to_image_filepath = to_image_dir / image_filename\n", + " from_image_filepath.rename(to_image_filepath)\n", + " val_annotations_file.unlink()\n", + " val_images_dir.rmdir()\n", + "\n", + "\n", + "DATASET_DIR = DATA_DIR / \"tiny-imagenet-200\"\n", + "if not DATASET_DIR.exists():\n", + " download_tiny_imagenet_200(DATA_DIR)\n", + " prepare_tiny_imagenet_200(DATASET_DIR)\n", + " print(f\"Successfully downloaded and prepared dataset at: {DATASET_DIR}\")\n", + "\n", + "BASE_MODEL_NAME = \"resnet18\"\n", + "image_size = 64\n", + "\n", + "OUTPUT_DIR.mkdir(exist_ok=True)\n", + "MODEL_DIR.mkdir(exist_ok=True)\n", + "DATA_DIR.mkdir(exist_ok=True)\n", + "\n", + "# Paths where PyTorch and OpenVINO IR models will be stored.\n", + "fp32_pth_path = Path(MODEL_DIR / (BASE_MODEL_NAME + \"_fp32\")).with_suffix(\".pth\")\n", + "fp32_ir_path = 
fp32_pth_path.with_suffix(\".xml\")\n", + "int8_sparse_ir_path = Path(MODEL_DIR / (BASE_MODEL_NAME + \"_int8_sparse\")).with_suffix(\".xml\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "E01dMaR2_AFL" + }, + "source": [ + "### Train Function\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "940rcAIyiXml" + }, + "outputs": [], + "source": [ + "def train(train_loader, model, compression_ctrl, criterion, optimizer, epoch):\n", + " batch_time = AverageMeter(\"Time\", \":3.3f\")\n", + " losses = AverageMeter(\"Loss\", \":2.3f\")\n", + " top1 = AverageMeter(\"Acc@1\", \":2.2f\")\n", + " top5 = AverageMeter(\"Acc@5\", \":2.2f\")\n", + " progress = ProgressMeter(\n", + " len(train_loader),\n", + " [batch_time, losses, top1, top5],\n", + " prefix=\"Epoch:[{}]\".format(epoch),\n", + " )\n", + "\n", + " # Switch to train mode.\n", + " model.train()\n", + "\n", + " end = time.time()\n", + " for i, (images, target) in enumerate(train_loader):\n", + " images = images.to(device)\n", + " target = target.to(device)\n", + "\n", + " # Compute output.\n", + " output = model(images)\n", + " loss = criterion(output, target)\n", + "\n", + " # Measure accuracy and record loss.\n", + " acc1, acc5 = accuracy(output, target, topk=(1, 5))\n", + " losses.update(loss.item(), images.size(0))\n", + " top1.update(acc1[0], images.size(0))\n", + " top5.update(acc5[0], images.size(0))\n", + "\n", + " # Compute gradient and do opt step.\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Measure elapsed time.\n", + " batch_time.update(time.time() - end)\n", + " end = time.time()\n", + "\n", + " print_frequency = 50\n", + " if i % print_frequency == 0:\n", + " progress.display(i)\n", + " compression_ctrl.scheduler.step()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "CoNr8qwm_El2" + }, + "source": [ + "### Validate Function\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "KgnugrWgicWC" + }, + "outputs": [], + "source": [ + "def validate(val_loader, model, criterion):\n", + " batch_time = AverageMeter(\"Time\", \":3.3f\")\n", + " losses = AverageMeter(\"Loss\", \":2.3f\")\n", + " top1 = AverageMeter(\"Acc@1\", \":2.2f\")\n", + " top5 = AverageMeter(\"Acc@5\", \":2.2f\")\n", + " progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], prefix=\"Test: \")\n", + "\n", + " # Switch to evaluate mode.\n", + " model.eval()\n", + "\n", + " with torch.no_grad():\n", + " end = time.time()\n", + " for i, (images, target) in enumerate(val_loader):\n", + " images = images.to(device)\n", + " target = target.to(device)\n", + "\n", + " # Compute output.\n", + " output = model(images)\n", + " loss = criterion(output, target)\n", + "\n", + " # Measure accuracy and record loss.\n", + " acc1, acc5 = accuracy(output, target, topk=(1, 5))\n", + " losses.update(loss.item(), images.size(0))\n", + " top1.update(acc1[0], images.size(0))\n", + " top5.update(acc5[0], images.size(0))\n", + "\n", + " # Measure elapsed time.\n", + " batch_time.update(time.time() - end)\n", + " end = time.time()\n", + "\n", + " print_frequency = 10\n", + " if i % print_frequency == 0:\n", + " progress.display(i)\n", + "\n", + " print(\" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\".format(top1=top1, top5=top5))\n", + " return top1.avg" + ] + }, + { + "attachments": 
{}, + "cell_type": "markdown", + "metadata": { + "id": "qMnYsGo9_MA8" + }, + "source": [ + "### Helpers\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R724tbxcidQE" + }, + "outputs": [], + "source": [ + "class AverageMeter(object):\n", + " \"\"\"Computes and stores the average and current value\"\"\"\n", + "\n", + " def __init__(self, name, fmt=\":f\"):\n", + " self.name = name\n", + " self.fmt = fmt\n", + " self.reset()\n", + "\n", + " def reset(self):\n", + " self.val = 0\n", + " self.avg = 0\n", + " self.sum = 0\n", + " self.count = 0\n", + "\n", + " def update(self, val, n=1):\n", + " self.val = val\n", + " self.sum += val * n\n", + " self.count += n\n", + " self.avg = self.sum / self.count\n", + "\n", + " def __str__(self):\n", + " fmtstr = \"{name} {val\" + self.fmt + \"} ({avg\" + self.fmt + \"})\"\n", + " return fmtstr.format(**self.__dict__)\n", + "\n", + "\n", + "class ProgressMeter(object):\n", + " def __init__(self, num_batches, meters, prefix=\"\"):\n", + " self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n", + " self.meters = meters\n", + " self.prefix = prefix\n", + "\n", + " def display(self, batch):\n", + " entries = [self.prefix + self.batch_fmtstr.format(batch)]\n", + " entries += [str(meter) for meter in self.meters]\n", + " print(\"\\t\".join(entries))\n", + "\n", + " def _get_batch_fmtstr(self, num_batches):\n", + " num_digits = len(str(num_batches // 1))\n", + " fmt = \"{:\" + str(num_digits) + \"d}\"\n", + " return \"[\" + fmt + \"/\" + fmt.format(num_batches) + \"]\"\n", + "\n", + "\n", + "def accuracy(output, target, topk=(1,)):\n", + " \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n", + " with torch.no_grad():\n", + " maxk = max(topk)\n", + " batch_size = target.size(0)\n", + "\n", + " _, pred = output.topk(maxk, 1, True, True)\n", + " pred = pred.t()\n", + " correct = pred.eq(target.view(1, -1).expand_as(pred))\n", + "\n", + " res = []\n", + " for k in topk:\n", + " correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n", + " res.append(correct_k.mul_(100.0 / batch_size))\n", + " return res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "kcSjyLBwiqBx", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Get a Pre-trained FP32 Model\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "А pre-trained floating-point model is a prerequisite for quantization. It can be obtained by tuning from scratch with the code below. 
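Fine-tuning the floating-point model can take a while. As an optional convenience, the tuned `FP32` weights could be cached at `fp32_pth_path` (defined earlier) and reloaded on later runs instead of repeating the tuning. The helpers below are a hypothetical sketch and are not part of the original notebook flow:

```python
# Hypothetical helpers (not in the original notebook): cache the tuned FP32
# weights so that later runs of the notebook can skip the fine-tuning step.
import torch


def save_fp32_checkpoint(model: torch.nn.Module, path) -> None:
    # Persist only the weights; the architecture is re-created from torchvision.
    torch.save(model.state_dict(), path)


def maybe_load_fp32_checkpoint(model: torch.nn.Module, path, device) -> bool:
    # Load cached weights if they exist; return True on success.
    if not path.exists():
        return False
    model.load_state_dict(torch.load(path, map_location=device))
    return True
```

For example, `maybe_load_fp32_checkpoint(model, fp32_pth_path, device)` could be called right after the model is created in the next cell, and `save_fp32_checkpoint(model, fp32_pth_path)` once tuning has finished.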
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "avCsioUYIaL7", + "outputId": "183bdbb6-4016-463c-8d76-636a6b3a9778", + "tags": [], + "test_replace": { + "train_dataset,": "torch.utils.data.Subset(train_dataset, torch.arange(300)), ", + "val_dataset, ": "torch.utils.data.Subset(val_dataset, torch.arange(100)), " + } + }, + "outputs": [], + "source": [ + "num_classes = 1000\n", + "init_lr = 1e-4\n", + "batch_size = 128\n", + "epochs = 20\n", + "\n", + "# model = models.resnet50(pretrained=True)\n", + "model = models.resnet18(pretrained=True)\n", + "model.fc = nn.Linear(in_features=512, out_features=200, bias=True)\n", + "model.to(device)\n", + "\n", + "\n", + "# Data loading code.\n", + "train_dir = DATASET_DIR / \"train\"\n", + "val_dir = DATASET_DIR / \"val\"\n", + "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", + "\n", + "train_dataset = datasets.ImageFolder(\n", + " train_dir,\n", + " transforms.Compose(\n", + " [\n", + " transforms.Resize([image_size, image_size]),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + " normalize,\n", + " ]\n", + " ),\n", + ")\n", + "val_dataset = datasets.ImageFolder(\n", + " val_dir,\n", + " transforms.Compose(\n", + " [\n", + " transforms.Resize([256, 256]),\n", + " transforms.CenterCrop([image_size, image_size]),\n", + " transforms.ToTensor(),\n", + " normalize,\n", + " ]\n", + " ),\n", + ")\n", + "\n", + "train_loader = torch.utils.data.DataLoader(\n", + " train_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " num_workers=1,\n", + " pin_memory=True,\n", + " sampler=None,\n", + ")\n", + "\n", + "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)\n", + "\n", + "# Define loss function (criterion) and optimizer.\n", + "criterion = nn.CrossEntropyLoss().to(device)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pt_xNDDrJKsy", + "outputId": "0925c801-0585-4431-98c9-de0decc4ad27", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Export the `FP32` model to OpenVINO™ Intermediate Representation, to benchmark it in comparison with the `INT8` model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9d8LOmKut36x", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "dummy_input = torch.randn(1, 3, image_size, image_size).to(device)\n", + "\n", + "ov_model = ov.convert_model(model, example_input=dummy_input, input=[1, 3, image_size, image_size])\n", + "ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False)\n", + "print(f\"FP32 model was exported to {fp32_ir_path}.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "pobVoHEoKcYp" + }, + "source": [ + "## Create and Initialize Quantization and Sparsity Training\n", + "\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "NNCF enables compression-aware training by integrating into regular training pipelines. The framework is designed so that modifications to your original training code are minor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nncf import NNCFConfig\n", + "from nncf.torch import create_compressed_model, register_default_init_args\n", + "\n", + "# load\n", + "nncf_config = NNCFConfig.from_json(\"config.json\")\n", + "nncf_config = register_default_init_args(nncf_config, train_loader)\n", + "\n", + "# Creating a compressed model\n", + "compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)\n", + "compression_ctrl.scheduler.epoch_step()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Validate Compressed Model" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluate the new model on the validation set after initialization of quantization and sparsity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "acc1 = validate(val_loader, compressed_model, criterion)\n", + "print(f\"Accuracy of initialized sparse INT8 model: {acc1:.3f}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine-tune the Compressed Model\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "At this step, a regular fine-tuning process is applied to further improve quantized model accuracy. Normally, several epochs of tuning are required with a small learning rate, the same that is usually used at the end of the training of the original model. No other changes in the training pipeline are required. Here is a simple example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "compression_lr = init_lr / 10\n", + "optimizer = torch.optim.Adam(compressed_model.parameters(), lr=compression_lr)\n", + "nr_epochs = 10\n", + "# Train for one epoch with NNCF.\n", + "print(\"Training\")\n", + "for epoch in range(nr_epochs):\n", + " compression_ctrl.scheduler.epoch_step()\n", + " train(train_loader, compressed_model, compression_ctrl, criterion, optimizer, epoch=epoch)\n", + "\n", + "# Evaluate on validation set after Quantization-Aware Training (QAT case).\n", + "print(\"Validating\")\n", + "acc1_int8_sparse = validate(val_loader, compressed_model, criterion)\n", + "\n", + "print(f\"Accuracy of tuned INT8 sparse model: {acc1_int8_sparse:.3f}\")\n", + "print(f\"Accuracy drop of tuned INT8 sparse model over pre-trained FP32 model: {acc1 - acc1_int8_sparse:.3f}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Export INT8 Sparse Model to OpenVINO IR\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "warnings.filterwarnings(\"ignore\", category=TracerWarning)\n", + "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", + "# Export INT8 model to OpenVINO™ IR\n", + "ov_model = ov.convert_model(compressed_model, example_input=dummy_input, input=[1, 3, image_size, image_size])\n", + "ov.save_model(ov_model, int8_sparse_ir_path)\n", + "print(f\"INT8 sparse model exported to {int8_sparse_ir_path}.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark Model Performance by Computing Inference Time\n", + 
"[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "Finally, measure the inference performance of the `FP32` and `INT8` models, using [Benchmark Tool](https://docs.openvino.ai/2024/learn-openvino/openvino-samples/benchmark-tool.html) - inference performance measurement tool in OpenVINO. By default, Benchmark Tool runs inference for 60 seconds in asynchronous mode on CPU. It returns inference speed as latency (milliseconds per image) and throughput (frames per second) values.\n", + "\n", + "> **NOTE**: This notebook runs `benchmark_app` for 15 seconds to give a quick indication of performance. For more accurate performance, it is recommended to run `benchmark_app` in a terminal/command prompt after closing other applications. Run `benchmark_app -m model.xml -d CPU` to benchmark async inference on CPU for one minute. Change CPU to GPU to benchmark on GPU. Run `benchmark_app --help` to see an overview of all command-line options." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ipywidgets as widgets\n", + "\n", + "# Initialize OpenVINO runtime\n", + "core = ov.Core()\n", + "device = widgets.Dropdown(\n", + " options=core.available_devices,\n", + " value=\"CPU\",\n", + " description=\"Device:\",\n", + " disabled=False,\n", + ")\n", + "\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def parse_benchmark_output(benchmark_output):\n", + " parsed_output = [line for line in benchmark_output if \"FPS\" in line]\n", + " print(*parsed_output, sep=\"\\n\")\n", + "\n", + "\n", + "print(\"Benchmark FP32 model (IR)\")\n", + "benchmark_output = ! benchmark_app -m $fp32_ir_path -d $device.value -api async -t 15\n", + "parse_benchmark_output(benchmark_output)\n", + "\n", + "print(\"Benchmark INT8 sparse model (IR)\")\n", + "benchmark_output = ! benchmark_app -m $int8_ir_path -d $device.value -api async -t 15\n", + "parse_benchmark_output(benchmark_output)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show Device Information for reference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "core.get_property(device.value, \"FULL_DEVICE_NAME\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "K5HPrY_d-7cV", + "E01dMaR2_AFL", + "qMnYsGo9_MA8", + "L0tH9KdwtHhV" + ], + "name": "NNCF Quantization PyTorch Demo (tiny-imagenet/resnet-18)", + "provenance": [] + }, + "kernelspec": { + "display_name": "gpu3_my_ov_notebooks", + "language": "python", + "name": "gpu3_my_ov_notebooks" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "openvino_notebooks": { + "imageUrl": "", + "tags": { + "categories": [ + "Model Training", + "Optimize" + ], + "libraries": [], + "other": [], + "tasks": [ + "Image Classification" + ] + } + }, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}