Skip to content

Commit

Permalink
Merge pull request #18 from dnum-mi/feat-OCR
Browse files Browse the repository at this point in the history
Feat ocr
  • Loading branch information
AurelienmartW authored Aug 8, 2024
2 parents 089f3ef + fe7eb7c commit 821d89e
Show file tree
Hide file tree
Showing 26 changed files with 317 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ __pycache__
*.jpeg
*.gif
*.weights
*.pth
#*.pth
#*.onnx
*.csv
models/
Expand Down
32 changes: 32 additions & 0 deletions MLpackages/basegun_ml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,45 @@ weapon_length,barrel_length,confidence_card=get_lengths(image_bytes)

<li> <b>confidence_card</b>: it corresponds to the confidence score for the card prediction. A card is used as a reference for the measure module

<li> If the gun is not detected, the exception <b>MissingGun</b> is raised

<li> If the card is not detected, the exception <b>MissingCard</b> is raised

## Alarm Model detection
```Python
from basegun_ml.ocr import is_alarm_weapon
#After the import the model is already warmed-up for faster inference

#Convert image to bytes
with open("test.jpg", "rb") as file:
image_bytes = file.read()

#Check whether the weapon is an alarm gun
alarm_model = is_alarm_weapon(image_bytes, quality_check=True )


```
### Variables description
<li> <b>alarm_model</b>: returns "alarm weapon from model" if the gun is recognised as one of the known alarm models, "alarm weapon PAK" if the gun carries the PAK marking, and "Not an alarm weapon" otherwise.

<li> <b>quality_check</b>: specifies whether the image quality analysis is run before the text detection (enabled by default)

<li> If the image quality is too low, the exception <b>LowQuality</b> is raised

<li> If no text is detected, the exception <b>MissingText</b> is raised



# Tests
Tests are available for the classification, length measurement and alarm model detection (OCR) tasks
```
pytest tests/test_classification.py
pytest tests/test_measure.py
pytest tests/test_OCR.py
```
# Credits

- This project uses the [Ultralytics Library](https://github.com/ultralytics/ultralytics)
- The oriented bounding box detection is inspired from [this YOLOV5 implementation](https://github.com/hukaixuan19970627/yolov5_obb)
- The image quality analysis uses [Pyiqa](https://github.com/chaofengc/IQA-PyTorch)
- The OCR tasks are computed using [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR?tab=readme-ov-file)
Binary file added MLpackages/basegun_ml/basegun_ml/CNNIQA.pth
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
15 changes: 15 additions & 0 deletions MLpackages/basegun_ml/basegun_ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from ultralytics import YOLO
import os
from basegun_ml.utils import load_models
from paddleocr import PaddleOCR
import torch
import pyiqa

# NOTE(review): presumably a workaround allowing several bundled OpenMP
# runtimes to coexist in one process - confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Directory and file name of this module; the bundled model files below are
# resolved relative to this_dir.
this_dir, this_filename = os.path.split(__file__)

Expand All @@ -13,3 +17,14 @@
os.path.join(this_dir, "./keypoints.pt"),
os.path.join(this_dir, "warmup.jpg"),
)
# OCR engine, created once at import time from the detection, recognition and
# classification model folders stored under PaddleModels/ next to this module.
model_ocr = PaddleOCR(
det_model_dir=os.path.join(this_dir, "PaddleModels/detection"),
rec_model_dir=os.path.join(this_dir, "PaddleModels/recognition"),
cls_model_dir=os.path.join(this_dir, "PaddleModels/classification"),
use_angle_cls=True,
show_log=False,
)
# The image-quality metric is created on the CPU device.
device = torch.device("cpu")
# "cnniqa" metric from pyiqa, with its weights loaded from the CNNIQA.pth file
# located next to this module.
metric_iqa = pyiqa.create_metric(
"cnniqa", device=device, pretrained_model_path=os.path.join(this_dir, "CNNIQA.pth")
)
Binary file modified MLpackages/basegun_ml/basegun_ml/classification.pt
Binary file not shown.
22 changes: 22 additions & 0 deletions MLpackages/basegun_ml/basegun_ml/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class MissingGun(Exception):
    """Raised when the gun is not detected in the measure module."""


class MissingCard(Exception):
    """Raised when the card is not detected in the measure module."""


class LowQuality(Exception):
    """Raised when the image does not have a sufficient quality."""


class MissingText(Exception):
    """Raised when text is not detected in the reading module."""
5 changes: 3 additions & 2 deletions MLpackages/basegun_ml/basegun_ml/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from basegun_ml import model_card, model_keypoints
from basegun_ml.utils import rotate, distanceCalculate, scalarproduct
from basegun_ml.exceptions import MissingCard, MissingGun


def get_card(image, model):
Expand Down Expand Up @@ -44,14 +45,14 @@ def get_lengths(imagebytes, draw=True, output_filename="result.jpg"):

keypoints = get_keypoints(image, model_keypoints)
if len(keypoints) == 0:
return (0, 0, 0)
raise MissingGun
if keypoints[3][0] < keypoints[0][0]: # Weapon upside down
image = cv2.rotate(image, cv2.ROTATE_180)
keypoints = get_keypoints(image, model_keypoints)

cards = get_card(image, model_card)
if len(cards) == 0:
return (0, 0, 0)
raise MissingCard
card = cards[0]
confCard = card[8]
CardP = distanceCalculate((card[0], card[1]), (card[4], card[5]))
Expand Down
171 changes: 171 additions & 0 deletions MLpackages/basegun_ml/basegun_ml/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from basegun_ml import model_ocr, metric_iqa
from fuzzysearch import find_near_matches
import io
import PIL.Image as Image
import numpy as np
from basegun_ml.exceptions import MissingText, LowQuality


# Cut-off used by quality_eval: an image passes the quality check when its
# CNNIQA score is strictly greater than this value.
QUALITY_THRESHOLD = 0.50


def get_text(results):
    """Flatten PaddleOCR detections into a single lowercase string.

    Args:
        results: iterable of detections for one image; the recognised text of
            each detection is read from ``detection[1][0]``.
    Returns:
        text: the recognised fragments in their original order, lowercased,
            with one space before the first fragment and one space after every
            fragment (a single space when ``results`` is empty).
    """
    fragments = [detection[1][0] + " " for detection in results]
    return (" " + "".join(fragments)).lower()


def is_in(word, phrase):
    """Check whether a word appears in a text, tolerating small OCR errors.

    The search is delegated to fuzzysearch's ``find_near_matches`` with
    ``max_l_dist=1``.

    Args:
        word: word looked for in the text
        phrase: text to explore
    Returns:
        boolean: true if at least one near match of word is found in phrase
    """
    return bool(find_near_matches(word, phrase, max_l_dist=1))


def is_alarm_model(text):
    """Apply the expert-defined rules that recognise an alarm model from OCR text.

    Brand and model names are looked up with the fuzzy matcher ``is_in``;
    model numbers must appear verbatim in ``text``. The first brand keyword
    found decides the outcome: once a brand matches, the remaining rules are
    not evaluated.

    Args:
        text: string of the extracted text
    Returns:
        boolean: true if an alarm model is recognized
    """
    # Fuzzy search for words but exact value for model numbers.
    # NOTE(review): "9o6" / "29o6" look like OCR confusions of 0 with o - confirm.
    zoraki_models = ["r2", "925", "92s", "906", "2906", "918", "9o6", "29o6"]

    # Blow
    if is_in("blow", text):
        return any(model in text for model in ["f92", "c75"])

    # Zoraki
    if is_in("zoraki", text):
        return any(model in text for model in zoraki_models)

    # Kimar
    if is_in("kimar", text):
        if is_in("auto", text):
            return "75" in text
        return "911" in text

    if is_in("auto", text):
        return any(model in text for model in ["92", "85"])

    # Names that identify an alarm model on their own.
    if is_in("lady k", text) or is_in("python", text) or "pk4" in text:
        return True

    # On this kind of model the word "kimar" is sometimes replaced by the logo.
    if is_in("alarm", text):
        return any(is_in(word, text) for word in ["competitive", "power"])

    return False


def is_pak(text):
    """Determine whether the text contains a PAK engraving.

    The comparison is case-insensitive: the text is lowercased before the
    search, so the function gives the same answer for raw OCR output and for
    text already normalised by ``get_text``.

    Args:
        text: string of the extracted text
    Returns:
        boolean: true if the PAK engraving is recognized
    """
    # All markings are lowercase. The former "P.A.Knall" entry is dropped
    # because any text containing it also contains "knall".
    markings = ("pak ", "p.a.k", "pak.", " pak", "pa.k", "p.ak", "knall")
    lowered = text.lower()
    return any(marking in lowered for marking in markings)


def quality_eval(img):
    """Score the image with the CNNIQA metric and compare it to the threshold.

    The image is resized to a width of 640 pixels, keeping its aspect ratio,
    before being passed to ``metric_iqa``.

    Args:
        img: PIL image
    Returns:
        The result of ``score > QUALITY_THRESHOLD``: truthy when the image
        quality is sufficient. Its exact type depends on what ``metric_iqa``
        returns.
    """
    import logging  # local import keeps this block self-contained

    width, height = img.size
    scale = 640 / width
    # Clamp to 1 so a very wide image cannot be resized to a zero-pixel height.
    target_size = (640, max(1, int(height * scale)))
    score = metric_iqa(img.resize(target_size))
    # Debug-level log replaces the print that wrote the score to stdout.
    logging.getLogger(__name__).debug("CNNIQA quality score: %s", score)
    return score > QUALITY_THRESHOLD


def is_alarm_weapon(image_bytes, quality_check=True):
    """Global pipeline for determining if the weapon is an alarm gun using OCR.

    Args:
        image_bytes: Bytes image from Basegun
        quality_check: when True (default) the image quality is evaluated with
            ``quality_eval`` before the OCR is run; when False that evaluation
            is skipped.
    Returns:
        string: "alarm weapon from model", "alarm weapon PAK" or
            "Not an alarm weapon"
    Raises:
        LowQuality: quality_check is enabled and the image fails quality_eval
        MissingText: the OCR returns no detection for the image
    """
    img = Image.open(io.BytesIO(image_bytes))

    # The quality verification can be bypassed by the caller.
    if quality_check and not quality_eval(img):
        raise LowQuality

    results = model_ocr.ocr(np.asarray(img), cls=True)
    # Results with recognition and detection confidence below 0.5 are filtered
    # by paddle; the threshold values can be changed.
    # A missing result (None, [] or [None]) or an empty detection list means
    # that no text was found.
    if not results or not results[0]:
        raise MissingText

    text = get_text(results[0])
    if is_alarm_model(text):
        return "alarm weapon from model"
    if is_pak(text):
        return "alarm weapon PAK"
    return "Not an alarm weapon"
21 changes: 13 additions & 8 deletions MLpackages/basegun_ml/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.package-data]
basegun_ml = ["*.pt","*.jpg","*.onnx"]

# Model weights and the warm-up image shipped inside the package.
basegun_ml = ["*.pt","*.jpg","*.onnx","*.pth"]
"basegun_ml.PaddleModels"=["**"]

[project]
name = "basegun_ml"
version = "1.0.1"
version = "2.0.0"
authors = [
{ name="aurelien martinez" },
]
Expand All @@ -20,11 +20,16 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"ultralytics",
"pillow",
"onnxruntime",
"opencv-python",
"numpy"
"ultralytics>=8.2.48",
"pillow>=7.1.2",
"onnxruntime>=1.9.0",
"opencv-python>=4.6.0",
"numpy>=1.21.6",
"fuzzysearch>=0.7.3",
"pyiqa==0.1.11",
"torch==2.3.1",
"paddleocr==2.7.3",
"paddlepaddle==2.6.1"
]
license = {file = "LICENSE"}

Expand Down
48 changes: 48 additions & 0 deletions MLpackages/tests/test_OCR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from basegun_ml.ocr import is_alarm_weapon
from basegun_ml.exceptions import LowQuality, MissingText
import os
import pytest


this_dir, this_filename = os.path.split(__file__)


def to_bytes(img):
    """Read the file at path ``img`` in binary mode and return its content."""
    with open(img, "rb") as handle:
        return handle.read()


class TestMeasure:
    """End-to-end checks of is_alarm_weapon on the sample OCR images."""

    # NOTE(review): the class name looks carried over from the measure tests;
    # it is kept so that existing test IDs do not change.

    @staticmethod
    def _load(filename):
        """Return the bytes of a sample image from tests_images/test_ocr.

        The path is built with os.path.join so that it stays relative when
        this_dir is an empty string.
        """
        return to_bytes(os.path.join(this_dir, "tests_images", "test_ocr", filename))

    def test_LowQuality(self):
        with pytest.raises(LowQuality):
            is_alarm_weapon(self._load("bad_quality.JPG"))

    def test_NoText(self):
        with pytest.raises(MissingText):
            is_alarm_weapon(self._load("no_text.JPG"))

    def test_LowQualityBypass(self):
        # With the quality check disabled the low-quality image goes through OCR.
        pred = is_alarm_weapon(self._load("bad_quality.JPG"), quality_check=False)
        assert pred == "Not an alarm weapon"

    def test_NotAlarm(self):
        pred = is_alarm_weapon(self._load("not_alarm.JPG"))
        assert pred == "Not an alarm weapon"

    def test_PAK(self):
        pred = is_alarm_weapon(self._load("PAK.JPG"))
        assert pred == "alarm weapon PAK"

    def test_AlarmModel(self):
        pred = is_alarm_weapon(self._load("alarm_model.JPG"))
        assert pred == "alarm weapon from model"
Loading

0 comments on commit 821d89e

Please sign in to comment.