✨ Add FINNger code
rafaeelaudibert committed May 26, 2021
0 parents commit ba2a733
Showing 14 changed files with 653 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
model/*
!model/.gitkeep

data/*
!data/*.zip

__pycache__
31 changes: 31 additions & 0 deletions README.md
@@ -0,0 +1,31 @@
# FINNger

FINNger is a CNN intended to detect how many raised fingers you are holding up through your webcam (or any other image-capturing device). The final goal of this work is a mobile app where children can learn basic arithmetic. This repository mostly contains the code for generating the model, together with a small proof of concept showing that it actually works.

More information about the work itself and the model can be found in the paper: _link unavailable at the moment_

## Installing

You need `Python3` to run this code. To install the library dependencies, run `pip3 install -r requirements.txt`.

## Dataset

By default, one of the datasets used is already included in this repository. From the repository root, run `./extract_dataset.sh` and the custom dataset will be extracted to `data/<dataset_name>`.

To download the [koryakinp/fingers](https://www.kaggle.com/koryakinp/fingers) dataset, refer to the Kaggle website for instructions on how to download it.

## Model

A trained model is not included in the repository. However, the final model and optimizer state are available on the releases tab for demonstration purposes.
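
As a rough sketch only (the checkpoint file names below are assumptions, not the actual release layout, and the network class itself is defined elsewhere in the repository), loading the released weights in PyTorch would look roughly like this:

```python
import torch

# Hypothetical asset names; check the releases page for the actual files.
model_state = torch.load("model/finnger.pth", map_location="cpu")
optimizer_state = torch.load("model/finnger_optimizer.pth", map_location="cpu")

# `model` stands in for an instance of the FINNger network class from this repository.
# model.load_state_dict(model_state)
# model.eval()
```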

## Results

As stated above, the full results can be found in the paper. The image below gives a small demonstration of the trained model's high accuracy on the validation images: rows are the expected values and columns are the values predicted by the FINNger model.

![Correlation Matrix for our model](images/nn_detection_corr.png)


## Authors

- [Rafael Baldasso Audibert](https://www.rafaaudibert.dev)
- Vinicius Maschio
27 changes: 27 additions & 0 deletions calculator.py
@@ -0,0 +1,27 @@
class Calculator:
    """Tiny two-operand adder; unfilled operands render as "?" when printed."""

    def __init__(self):
        self.a = None
        self.b = None

    def add_number(self, number):
        # Fill the first empty operand; a third call resets the calculator
        # (the incoming number is discarded).
        if self.a is None:
            self.a = number
        elif self.b is None:
            self.b = number
        else:
            self.a = None
            self.b = None

    @property
    def result(self):
        if self.a is None or self.b is None:
            return None

        return self.a + self.b

    def __str__(self):
        a = "?" if self.a is None else str(self.a)
        b = "?" if self.b is None else str(self.b)
        c = "?" if self.a is None or self.b is None else str(self.result)

        return f"{a} + {b} = {c}"
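
A short usage sketch, not part of the original file: each call to `add_number` fills the next empty operand, and a third call clears both.

```python
from calculator import Calculator  # assuming this snippet lives alongside calculator.py

calc = Calculator()
print(calc)         # ? + ? = ?

calc.add_number(2)  # first operand
calc.add_number(3)  # second operand
print(calc)         # 2 + 3 = 5
print(calc.result)  # 5

calc.add_number(4)  # third call resets both operands; the 4 is discarded
print(calc)         # ? + ? = ?
```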
Binary file added data/custom_fingers.zip
49 changes: 49 additions & 0 deletions dataset.py
@@ -0,0 +1,49 @@
import torch
import numpy as np
from torch.utils.data import Dataset
import glob
import cv2
from tqdm import tqdm


def Identity(x): return x


class FINNgerDataset(Dataset):
"""Hand Images dataset available at https://www.kaggle.com/koryakinp/fingers."""

NUM_CLASSES = 6

def __init__(self, data_dir, transform=Identity):
"""
Args:
data_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample, by default is Identity.
"""
self.data_dir = data_dir
self.transform = transform

self.glob_path = glob.glob(data_dir)
self.dataset = []
for img_path in tqdm(self.glob_path, desc="Import data"):
            # File names encode the class just before the extension (e.g. ..._3L.png in the
            # koryakinp/fingers data); grab the digit six characters from the end as the label
            image_label = int(img_path[-6:-5])

image = cv2.imread(img_path)

self.dataset.append({'image': image, 'label': image_label})
self.dataset = np.array(self.dataset)

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

        # Return an (image, label) tuple so samples plug directly into a DataLoader
sample = self.dataset[idx]
return (
self.transform(sample['image']),
sample['label'],
)
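
A hedged usage sketch, not part of the commit: it wires `FINNgerDataset` into a PyTorch `DataLoader` using the glob pattern and batch size from `default_config.py`. The float normalization transform here is an assumption; the real training transform is defined elsewhere.

```python
from torch.utils.data import DataLoader

from dataset import FINNgerDataset
from default_config import DEFAULT_BATCH_SIZE, DEFAULT_TRAIN_DATASET


def to_float(image):
    # cv2.imread returns HxWxC uint8 arrays; scale to float32 in [0, 1].
    return image.astype("float32") / 255.0


train_dataset = FINNgerDataset(DEFAULT_TRAIN_DATASET, transform=to_float)
train_loader = DataLoader(train_dataset, batch_size=DEFAULT_BATCH_SIZE, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([8, 128, 128, 3]) torch.Size([8])
```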
87 changes: 87 additions & 0 deletions dataset_generator.py
@@ -0,0 +1,87 @@
from threading import Thread
import uuid

import cv2
import PySimpleGUI as sg
import click


key_pressed: bool = False
pictures_taken = 0


def detect_key_press():
    """Runs in a background thread: blocks on stdin and flags when the user hits enter."""
    global key_pressed
    already_thanked = False

    while True:
        input("Press Enter to take a screenshot")
        key_pressed = True
        if not already_thanked:
            print("Thank you!")
            already_thanked = True


def save_image(image, identifier: str, path: str, size: int) -> None:
global pictures_taken

generated_uuid = uuid.uuid4()
full_name = f"{generated_uuid}_{identifier}.png"

resized_array = cv2.resize(image, (size, size))
    # Frames from OpenCV come in BGR order, so use the BGR grayscale conversion
    black_and_white = cv2.cvtColor(resized_array, cv2.COLOR_BGR2GRAY)

full_path = path + full_name
cv2.imwrite(full_path, black_and_white)

pictures_taken += 1
print(f"Saved image {pictures_taken} to {full_path}")

print("Shape read back is", cv2.imread(full_path).shape)


@click.command()
@click.option("--identifier", required=True, help="The identifier appended to the end of the image")
@click.option("--path", required=True, help="The path where the photos will be saved on")
@click.option("--size", default=128, help="Size to resize the image to", show_default=True)
@click.option("--n_images", default=float("inf"), help="How many images we should generate", show_default=True)
def main(identifier: str, path: str, size: int, n_images: int):
global key_pressed

    # Background daemon thread used to detect key presses on stdin, so the blocking
    # input() call does not keep the process alive after the window is closed
    thread = Thread(target=detect_key_press, daemon=True)
    thread.start()

window = sg.Window(
'Dataset Generator',
[[sg.Image(filename='', key='image')], ],
location=(800, 400),
)

cap = cv2.VideoCapture(0) # Setup the camera as a capture device
while True:
# get events for the window with 20ms max wait
event, _values = window.Read(timeout=20, timeout_key='timeout')
if event is None: # if user closed window, quit
break

_ret, image = cap.read()

# Update image in window
window_image = window.FindElement('image')
encoded_image = cv2.imencode('.png', image)[1].tobytes()
window_image.Update(data=encoded_image)

# This is handled in a different thread, responsible for detecting the key press
if key_pressed:
save_image(image, identifier, path, size)
key_pressed = False

if pictures_taken >= n_images:
print(
f"Finished generating {pictures_taken} images. Quitting application...")
break


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions default_config.py
@@ -0,0 +1,7 @@
DEFAULT_LEARNING_RATE = 0.0003
DEFAULT_WEIGHT_DECAY = 1e-4

DEFAULT_TRAIN_DATASET = "data/fingers/train/*.png"
DEFAULT_TEST_DATASET = "data/fingers/test/*.png"

DEFAULT_BATCH_SIZE = 8
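
The training entry point is not shown in this excerpt; as a rough sketch under that assumption, these defaults would typically feed a PyTorch optimizer along these lines (the stand-in model below is hypothetical, assuming 128x128 grayscale inputs and the 6 finger-count classes):

```python
import torch

from default_config import DEFAULT_LEARNING_RATE, DEFAULT_WEIGHT_DECAY

# Hypothetical stand-in for the real CNN; the actual architecture is not part of this excerpt.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(128 * 128, 6),
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=DEFAULT_LEARNING_RATE,
    weight_decay=DEFAULT_WEIGHT_DECAY,
)
```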
2 changes: 2 additions & 0 deletions extract_dataset.sh
@@ -0,0 +1,2 @@
unzip data/fingers.zip -d data
unzip data/custom_fingers.zip -d data
Binary file added images/nn_detection_corr.png