
Simple CNN classification example using PyTorch: training, exporting from PyTorch to ONNX, and inference with both PyTorch and onnxruntime.


k2-gc/Simple-CNN-Example

Simple-CNN-Example

Introduction

This repository shows how to train a deep learning classification model with PyTorch, export it to ONNX, and run it with onnxruntime, using the MNIST handwritten-digit dataset as an example. CNN models typically expect 3-channel (RGB) input, but MNIST images have a single channel. To handle this, a custom MNIST dataset class inherits from "torchvision.datasets.MNIST" and returns 3-channel tensors.

Prerequisites

  • Docker
  • Docker compose
  • docker login nvcr.io
  • dGPU (Recommended)

How to train

Train with dGPU

docker compose -f docker-compose-gpu.yaml up -d
docker exec -it mnist_train /bin/bash
python train.py

Train with CPU

docker compose -f docker-compose.yaml up -d
docker exec -it mnist_train /bin/bash
python train.py

Export model from pytorch to onnx

After training, run the command below.

python export.py
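For readers unfamiliar with the export step, a script of this kind usually passes the trained model and a dummy input to `torch.onnx.export`. The sketch below is an assumption about what such a script might contain, not the contents of export.py: `TinyCNN`, `export_to_onnx`, the output path, and the tensor names "input" and "logits" are all hypothetical.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in model: 3-channel 28x28 input, 10 digit classes."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(16 * 7 * 7, 10)

    def forward(self, x):
        return self.classifier(torch.flatten(self.features(x), 1))


def export_to_onnx(model: nn.Module, path: str = "model.onnx") -> None:
    """Trace the model with a dummy batch and write an ONNX file to `path`."""
    model.eval()  # use inference behavior for layers such as dropout/batchnorm
    dummy = torch.randn(1, 3, 28, 28)
    torch.onnx.export(
        model, dummy, path,
        input_names=["input"], output_names=["logits"],
        # Mark the batch dimension as variable so other batch sizes also work.
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    )
```

In a real script the trained weights would be loaded into the model (for example with `load_state_dict`) before calling `export_to_onnx`.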

Run onnx with onnxruntime

python check_onnx_inferenc.py

The script above chooses 3 sample images from the MNIST dataset, runs inference on them, and shows the results from both the PyTorch model and the ONNX model.
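A comparison of this kind can be sketched with onnxruntime and NumPy. The helpers below are hypothetical (`predicted_digits`, `run_onnx`, `outputs_agree`) and are not taken from the repository. They assume the ONNX model has a single input and that its first output holds the class logits.

```python
import numpy as np


def predicted_digits(logits: np.ndarray) -> np.ndarray:
    """Return the index of the highest logit for each row of a (N, classes) array."""
    return logits.argmax(axis=1)


def run_onnx(onnx_path: str, batch: np.ndarray) -> np.ndarray:
    """Run a float32 batch through an ONNX model on CPU and return its first output."""
    import onnxruntime as ort  # imported lazily so the other helpers work without it

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: batch.astype(np.float32)})[0]


def outputs_agree(torch_logits: np.ndarray, onnx_logits: np.ndarray,
                  atol: float = 1e-4) -> bool:
    """True if both backends predict the same classes and the logits match within atol."""
    same_class = np.array_equal(
        predicted_digits(torch_logits), predicted_digits(onnx_logits)
    )
    return bool(same_class and np.allclose(torch_logits, onnx_logits, atol=atol))
```

The PyTorch side would be obtained by running the same batch through the model under `torch.no_grad()` and converting the result with `.numpy()`. Small numerical differences between the two backends are expected, which is why a tolerance is used rather than exact equality.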