Что за проект? Что он делает?
Detect Writed Number - задача по распознаванию рукописных цифр. В общем смысле ее можно разбить на две части: обучение и сохранение модели с настроенными весами и графическое окно для рисования цифры.
Таких проектов миллион. Не изобретай велосипед и используй Feature Extraction.
Поскольку подобных задач в сети куча, то в каком-то смысле ее стоит рассматривать с точки зрения тренировки и опыта. Pet project.
Input:
Output:
[INFO] User painted '7'
Основная часть проекта. Вкратце о том, что происходит в предыдущей части:
- Обучение модели на датасете MNIST;
- Тестирование модели;
- Сохранение модели в формате .pt
torch.save(mnist_net, 'mnist_full_model.pt')
.
...Последний пункт плавно перетекает в текущую часть проекта. В данной части реализация окна для рисования, распознавания и сохранения нарисованного изображения.
Поскольку стоит задача распознавания рукописных цифр, то для нее отлично подойдет MNIST-датасет, имеющий 60'000 тренировочных и 10'000 тестовых одноканальных изображений с расширением 28x28. Тренировочный датасет поделим на train и val в соотрошении 9:1 соответственно. Благодаря тому, что MNIST-датасет сбалансированный(все классы имеют одинаковое количество фич), то можно использовать acc как основную метрику. Именно по ней будем ориентироваться и выберем лучшие веса при лучшем accuracy.
Архитектура использует несколько идущих подряд блоков из conv-слоя батч-нормализации, LeakyReLU и max-пуллинга, в конце применяется несколько линейных слоев с ReLU. Подробнее в файле MNISTNet.py.
Хоть и можно было воспользоваться методом Feature Extraction, разморозить последний слой и дообучить тот же самый ResNet. Но было принято решение реализовать собственноручно архитектуру нейросети.
Путем длительного подбора оптимальных параметров было установлено, что нейросеть лучше всего обучается при:
- оптимизатор модели:
learning_rate=5e-4
- размер батча:
batch_size=50
- Conv2d:
kernel_size=3
иpadding=1(и еще 0)
- MaxPool2d:
kernel_size=2
иstride=2
Тренировка и валидация проходят в однй функции - train, сначала проходит эпоха для обучения, а потом - для валидации. На каждой фазе вычисляются такие метрики, как acc, loss, precision, recall, F-мера и macro F-мера. Пример:
...
----------
Epoch 39/40
Phase: train; Loss: 0.0027, Acc: 0.9999, Pre: 0.9993, Rec: 0.9993, macro-avr F1: 0.9993, avr F1: 0.9993
Phase: val; Loss: 0.0546, Acc: 0.9985, Pre: 0.9923, Rec: 0.9922, macro-avr F1: 0.9922, avr F1: 0.9922
Epoch time = 0:00:31.609002
----------
...
В процессе обучения запоминаются веса, при которых был лучший accuracy на валидационной выборке.
В функции test вычисляются такие же метрики, как и в train. Еще добавляется коллекция maximum_class_probabilities для вычисления доверительного порога для нейросети. Поскольку выход не нормируется софтмаксом, то можно будет хранить выход из нейросети, когда predict==label. После прохождения по всей тестовой выборке, находим доверительный интервал с alpha=0.95.
После расчета доверительного интервала выбираем минимальное значение и сохраняем его с именем trusted_threshold
в файл outputs/data.json
Данная часть реализована в файле processing.py. Обученная модель подгружается и используется для анализа собственно составленного небольшого датасета в папке inputs/. Эта папка содержит 2 подкаталога:
Содержит лейблы для сравнения и можно посчитать метрики.
Не содержит лейблы, разделена на часть, где есть цифры, и где их нет, но есть похожие на цифры предметы.
После выполнения данной процедуры будут сформированы файлы outputs/test.csv и outputs/check.scv с полезной информацией, которую выдала нейросеть.
Пример:
Пример:
Поскольку вторая часть написана на С++, то необходимо подружить обученную модель на python с оконным приложением на С++. Это можно сделать благодаря библиотеке TorchScript
, которая скриптует/трассирует модель. И в таком виде обученную модель можно считать кросс-языковой.
Модель дает accuracy=0.9990 на тестовой выборке, что является довольно неплохим результатом для задачи MNIST-датасета.
Скрин изменения acc и loss:
Как видно, после 22-й эпохи начинает расти loss. Значит, модель начала переобучаться, поэтому стоит приостановить обучение на этом этапе.
- Поддержка расширенных настроек через конфигурационный файл.
- Логгирование полезной информации.
- Работа из-под коробки. Скачали, собрали, запустили, нарисовали.
- Поддержка скриптования предобученных классических нейросетей для сравнения. Feature Extraction.