diff --git a/.gitignore b/.gitignore
index e59267a..533d361 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,7 @@ __pycache__
 *.jpeg
 *.gif
 *.weights
-*.pth
+#*.pth
 #*.onnx
 *.csv
 models/
diff --git a/MLpackages/basegun_ml/README.md b/MLpackages/basegun_ml/README.md
index 15046d5..d693c8a 100644
--- a/MLpackages/basegun_ml/README.md
+++ b/MLpackages/basegun_ml/README.md
@@ -48,13 +48,45 @@ weapon_length,barrel_length,confidence_card=get_lengths(image_bytes)
   • confidence_card: the confidence score for the card prediction. A card is used as a size reference by the measure module
+
+  • If the gun is not detected, the exception MissingGun is raised
+
+  • If the card is not detected, the exception MissingCard is raised
+
+## Alarm Model detection
+```Python
+from basegun_ml.ocr import is_alarm_weapon
+#After the import the model is already warmed-up for faster inference
+
+#Convert image to bytes
+with open("test.jpg", "rb") as file:
+    image_bytes = file.read()
+
+#Predict whether the gun is an alarm weapon
+alarm_model = is_alarm_weapon(image_bytes, quality_check=True)
+
+
+```
+### Variables description
+
+  • alarm_model: if the gun is one of the known alarm models, it returns "alarm weapon from model". If the gun has the PAK marking, it returns "alarm weapon PAK". Otherwise it returns "Not an alarm weapon".
+
+  • quality_check: specifies whether the image quality analysis is run before the text detection
+
+  • If the image quality is too low, the exception LowQuality is raised
+
+  • If no text is detected, the exception MissingText is raised
+
+
 # Tests
 Tests are available for the classification task and the measure length task
 ```
 pytest tests/test_classification.py
 pytest tests/test_measure.py
+pytest tests/test_OCR.py
 ```
 
 # Credits
 - This project uses the [Ultralytics Library](https://github.com/ultralytics/ultralytics)
 - The oriented bounding box detection is inspired by [this YOLOV5 implementation](https://github.com/hukaixuan19970627/yolov5_obb)
+- The image quality analysis uses [Pyiqa](https://github.com/chaofengc/IQA-PyTorch)
+- The OCR tasks are performed with [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR?tab=readme-ov-file)
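
For context on the README section above, here is a minimal sketch of how a caller might handle the exception flow of `is_alarm_weapon`; the image path is illustrative and not part of this diff:

```python
from basegun_ml.ocr import is_alarm_weapon
from basegun_ml.exceptions import LowQuality, MissingText

# Illustrative image path
with open("test.jpg", "rb") as file:
    image_bytes = file.read()

try:
    # quality_check=True runs the image quality gate before the OCR step
    result = is_alarm_weapon(image_bytes, quality_check=True)
    print(result)  # "alarm weapon from model", "alarm weapon PAK" or "Not an alarm weapon"
except LowQuality:
    print("Image quality too low, please retake the photo")
except MissingText:
    print("No text detected on the weapon")
```
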
diff --git a/MLpackages/basegun_ml/basegun_ml/CNNIQA.pth b/MLpackages/basegun_ml/basegun_ml/CNNIQA.pth
new file mode 100644
index 0000000..8f25f45
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/CNNIQA.pth differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/._inference.pdmodel b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/._inference.pdmodel
new file mode 100644
index 0000000..87503bf
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/._inference.pdmodel differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams
new file mode 100644
index 0000000..3449efb
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams.info b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams.info
new file mode 100644
index 0000000..f31a157
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdiparams.info differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdmodel b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdmodel
new file mode 100644
index 0000000..b90c155
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/classification/inference.pdmodel differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams
new file mode 100644
index 0000000..089594a
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams.info b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams.info
new file mode 100644
index 0000000..082c148
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdiparams.info differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdmodel b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdmodel
new file mode 100644
index 0000000..223b861
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/detection/inference.pdmodel differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams
new file mode 100644
index 0000000..4c3d9e9
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams.info b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams.info
new file mode 100644
index 0000000..923329f
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdiparams.info differ
diff --git a/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdmodel b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdmodel
new file mode 100644
index 0000000..dccddcc
Binary files /dev/null and b/MLpackages/basegun_ml/basegun_ml/PaddleModels/recognition/inference.pdmodel differ
diff --git a/MLpackages/basegun_ml/basegun_ml/__init__.py b/MLpackages/basegun_ml/basegun_ml/__init__.py
index 67cd075..c0ae840 100644
--- a/MLpackages/basegun_ml/basegun_ml/__init__.py
+++ b/MLpackages/basegun_ml/basegun_ml/__init__.py
@@ -1,7 +1,11 @@
 from ultralytics import YOLO
 import os
 from basegun_ml.utils import load_models
+from paddleocr import PaddleOCR
+import torch
+import pyiqa
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 
 this_dir, this_filename = os.path.split(__file__)
 
@@ -13,3 +17,14 @@
     os.path.join(this_dir, "./keypoints.pt"),
     os.path.join(this_dir, "warmup.jpg"),
 )
+model_ocr = PaddleOCR(
+    det_model_dir=os.path.join(this_dir, "PaddleModels/detection"),
+    rec_model_dir=os.path.join(this_dir, "PaddleModels/recognition"),
+    cls_model_dir=os.path.join(this_dir, "PaddleModels/classification"),
+    use_angle_cls=True,
+    show_log=False,
+)
+device = torch.device("cpu")
+metric_iqa = pyiqa.create_metric(
+    "cnniqa", device=device, pretrained_model_path=os.path.join(this_dir, "CNNIQA.pth")
+)
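
The `__init__.py` change above loads PaddleOCR and the pyiqa CNNIQA metric once at import time and exposes them as package-level objects (`model_ocr`, `metric_iqa`). A rough sketch of reusing them directly, mirroring the call pattern in `ocr.py`; the `float()` conversion and the reading of a higher CNNIQA score as better quality are assumptions, not something this diff states:

```python
import numpy as np
import PIL.Image as Image

from basegun_ml import metric_iqa, model_ocr

img = Image.open("engraving.jpg")  # illustrative path

# Image quality score from the bundled CNNIQA weights (the pyiqa metric is callable)
quality_score = float(metric_iqa(img))
print(f"CNNIQA score: {quality_score:.2f}")

# Raw PaddleOCR output: one list per image, each detected line as [box, (text, confidence)]
results = model_ocr.ocr(np.asarray(img), cls=True)
if results != [None]:
    for box, (text, confidence) in results[0]:
        print(text, confidence)
```
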
diff --git a/MLpackages/basegun_ml/basegun_ml/classification.pt b/MLpackages/basegun_ml/basegun_ml/classification.pt
index 4839a84..3b7319d 100644
Binary files a/MLpackages/basegun_ml/basegun_ml/classification.pt and b/MLpackages/basegun_ml/basegun_ml/classification.pt differ
diff --git a/MLpackages/basegun_ml/basegun_ml/exceptions.py b/MLpackages/basegun_ml/basegun_ml/exceptions.py
new file mode 100644
index 0000000..3235a9f
--- /dev/null
+++ b/MLpackages/basegun_ml/basegun_ml/exceptions.py
@@ -0,0 +1,22 @@
+class MissingGun(Exception):
+    "Raised when the gun is not detected in the measure module"
+
+    pass
+
+
+class MissingCard(Exception):
+    "Raised when the card is not detected in the measure module"
+
+    pass
+
+
+class LowQuality(Exception):
+    "Raised when the image does not have a sufficient quality"
+
+    pass
+
+
+class MissingText(Exception):
+    "Raised when text is not detected in the reading module"
+
+    pass
diff --git a/MLpackages/basegun_ml/basegun_ml/measure.py b/MLpackages/basegun_ml/basegun_ml/measure.py
index f4872b0..ba3d84f 100644
--- a/MLpackages/basegun_ml/basegun_ml/measure.py
+++ b/MLpackages/basegun_ml/basegun_ml/measure.py
@@ -2,6 +2,7 @@
 import numpy as np
 from basegun_ml import model_card, model_keypoints
 from basegun_ml.utils import rotate, distanceCalculate, scalarproduct
+from basegun_ml.exceptions import MissingCard, MissingGun
 
 
 def get_card(image, model):
@@ -44,14 +45,14 @@ def get_lengths(imagebytes, draw=True, output_filename="result.jpg"):
     keypoints = get_keypoints(image, model_keypoints)
 
     if len(keypoints) == 0:
-        return (0, 0, 0)
+        raise MissingGun
     if keypoints[3][0] < keypoints[0][0]:  # Weapon upside down
         image = cv2.rotate(image, cv2.ROTATE_180)
         keypoints = get_keypoints(image, model_keypoints)
 
     cards = get_card(image, model_card)
     if len(cards) == 0:
-        return (0, 0, 0)
+        raise MissingCard
     card = cards[0]
     confCard = card[8]
     CardP = distanceCalculate((card[0], card[1]), (card[4], card[5]))
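
The `measure.py` change above replaces the old `(0, 0, 0)` sentinel return with the `MissingGun`/`MissingCard` exceptions. A minimal sketch of how calling code might adapt, with an illustrative image path:

```python
from basegun_ml.measure import get_lengths
from basegun_ml.exceptions import MissingCard, MissingGun

# Illustrative image path: a photo containing the weapon and the reference card
with open("weapon_photo.jpg", "rb") as file:
    image_bytes = file.read()

try:
    weapon_length, barrel_length, confidence_card = get_lengths(image_bytes, draw=False)
except MissingGun:
    print("No gun detected, cannot measure")
except MissingCard:
    print("No reference card detected, cannot scale the measurement")
```
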
diff --git a/MLpackages/basegun_ml/basegun_ml/ocr.py b/MLpackages/basegun_ml/basegun_ml/ocr.py
new file mode 100644
index 0000000..7ea15bf
--- /dev/null
+++ b/MLpackages/basegun_ml/basegun_ml/ocr.py
@@ -0,0 +1,171 @@
+from basegun_ml import model_ocr, metric_iqa
+from fuzzysearch import find_near_matches
+import io
+import PIL.Image as Image
+import numpy as np
+from basegun_ml.exceptions import MissingText, LowQuality
+
+
+QUALITY_THRESHOLD = 0.50
+
+
+def get_text(results):
+    """extracts raw text from PaddleOCR output
+    Args:
+        results: raw result from PaddleOCR
+
+    Returns:
+        text: A string with the text extracted from the image
+    """
+    text = " "
+    for result in results:
+        text = text + result[1][0] + " "
+    return text.lower()
+
+
+def is_in(word, phrase):
+    """Check if a word appears in a phrase, using the fuzzysearch algorithm to tolerate small OCR errors
+    Args:
+        word: word to look for in the text
+        phrase: text to explore
+
+    Returns:
+        boolean: true if word is in phrase
+    """
+    res = find_near_matches(word, phrase, max_l_dist=1)
+    return len(res) > 0
+
+
+def is_alarm_model(text):
+    """determine if the text comes from an alarm model weapon image, using rules defined with weapon experts
+    Args:
+        text: extracted text
+
+    Returns:
+        boolean: true if an alarm model is recognized
+    """
+    # fuzzy search for brand words but exact match for model numbers
+    zoraki = ["r2", "925", "92s", "906", "2906", "918", "9o6", "29o6"]
+
+    # Blow
+    if is_in("blow", text):
+        if any(word in text for word in ["f92", "c75"]):
+            return True
+        else:
+            return False
+    # Zoraki
+    elif is_in("zoraki", text):
+        if any(word in text for word in zoraki):
+            return True
+        else:
+            return False
+    # Kimar
+    elif is_in("kimar", text):
+        if is_in("auto", text):
+            if "75" in text:
+                return True
+            else:
+                return False
+        elif "911" in text:
+            return True
+        else:
+            return False
+    elif is_in("auto", text):
+        if any(word in text for word in ["92", "85"]):
+            return True
+        else:
+            return False
+    elif is_in("lady k", text):
+        return True
+    elif is_in("python", text):
+        return True
+    elif "pk4" in text:
+        return True
+    elif is_in(
+        "alarm", text
+    ):  # On this type of model the word kimar is sometimes replaced by the logo
+        if any(is_in(word, text) for word in ["competitive", "power"]):
+            return True
+        else:
+            return False
+
+    else:
+        return False
+
+
+def is_pak(text):
+    """determine if the text comes from an alarm model weapon image with a PAK engraving
+    Args:
+        text: extracted text
+
+    Returns:
+        boolean: true if the PAK engraving is recognized
+    """
+    if any(
+        word in text
+        for word in [
+            "pak ",
+            "p.a.k",
+            "pak.",
+            " pak",
+            "pa.k",
+            "p.ak",
+            "knall",
+            "p.a.knall",
+        ]
+    ):
+        return True
+    else:
+        return False
+
+
+def quality_eval(img):
+    """Evaluate the CNNIQA image quality score and compare it to a defined threshold
+    Args:
+        img: PIL image
+
+    Returns:
+        boolean: true if the image has a good quality (score above QUALITY_THRESHOLD)
+    """
+    # A higher CNNIQA score is treated here as better image quality
+    score = metric_iqa(img)
+    return float(score) > QUALITY_THRESHOLD
+
+
+def is_alarm_weapon(image_bytes, quality_check=True):
+    """Global pipeline for determining if the weapon is an alarm gun using OCR
+    Args:
+        image_bytes: bytes image from Basegun
+        quality_check: if True, run the image quality analysis before the text detection
+
+    Returns:
+        string: user feedback on the alarm gun assessment
+    """
+
+    img = Image.open(io.BytesIO(image_bytes))
+    if (
+        quality_check
+    ):  # the image quality check can optionally be skipped
+        eval = quality_eval(img)
+    else:
+        eval = True
+
+    if eval:
+        results = model_ocr.ocr(np.asarray(img), cls=True)
+        if (
+            results != [None]
+        ):  # results with recognition or detection confidence below 0.5 are filtered by paddle, the threshold values can be changed
+            text = get_text(results[0])
+            if is_alarm_model(text):
+                return "alarm weapon from model"
+            elif is_pak(text):
+                return "alarm weapon PAK"
+            else:
+                return "Not an alarm weapon"
+        else:
+            raise MissingText
+    else:
+        raise LowQuality
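
The `is_in` helper in `ocr.py` above relies on fuzzysearch's `find_near_matches` with `max_l_dist=1`, so a single OCR error (one substituted, missing or extra character) still matches. A quick illustration with made-up engraving strings:

```python
from fuzzysearch import find_near_matches

# One substituted character ("zorakl" vs "zoraki") is still found with max_l_dist=1
engraved_text = " zorakl 906 cal.9mm p.a.k "
print(find_near_matches("zoraki", engraved_text, max_l_dist=1))  # one Match object

# Two errors exceed the allowed Levenshtein distance, so nothing is found
print(find_near_matches("zoraki", " zorkl 906 ", max_l_dist=1))  # []
```
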
diff --git a/MLpackages/basegun_ml/pyproject.toml b/MLpackages/basegun_ml/pyproject.toml
index e212f35..09fb23e 100644
--- a/MLpackages/basegun_ml/pyproject.toml
+++ b/MLpackages/basegun_ml/pyproject.toml
@@ -3,12 +3,12 @@
 requires = ["setuptools"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools.package-data]
-basegun_ml = ["*.pt","*.jpg","*.onnx"]
-
+basegun_ml = ["*.pt","*.jpg","*.onnx","*pth"]
+"basegun_ml.PaddleModels"=["**"]
 [project]
 name = "basegun_ml"
-version = "1.0.1"
+version = "2.0.0"
 authors = [
   { name="aurelien martinez" },
 ]
@@ -20,11 +20,16 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "ultralytics",
-    "pillow",
-    "onnxruntime",
-    "opencv-python",
-    "numpy"
+    "ultralytics>=8.2.48",
+    "pillow>=7.1.2",
+    "onnxruntime>=1.9.0",
+    "opencv-python>=4.6.0",
+    "numpy>=1.21.6",
+    "fuzzysearch>=0.7.3",
+    "pyiqa==0.1.11",
+    "torch==2.3.1",
+    "paddleocr==2.7.3",
+    "paddlepaddle==2.6.1"
 ]
 
 license = {file = "LICENSE"}
diff --git a/MLpackages/tests/test_OCR.py b/MLpackages/tests/test_OCR.py
new file mode 100644
index 0000000..1b192fe
--- /dev/null
+++ b/MLpackages/tests/test_OCR.py
@@ -0,0 +1,48 @@
+from basegun_ml.ocr import is_alarm_weapon
+from basegun_ml.exceptions import LowQuality, MissingText
+import os
+import pytest
+
+
+this_dir, this_filename = os.path.split(__file__)
+
+
+def to_bytes(img):
+    with open(img, "rb") as file:
+        image_bytes = file.read()
+    return image_bytes
+
+
+class TestMeasure:
+    def test_LowQuality(self):
+        with pytest.raises(LowQuality):
+            is_alarm_weapon(
+                to_bytes(this_dir + "/tests_images/test_ocr/bad_quality.JPG")
+            )
+
+    def test_NoText(self):
+        with pytest.raises(MissingText):
+            is_alarm_weapon(to_bytes(this_dir + "/tests_images/test_ocr/no_text.JPG"))
+
+    def test_LowQualityBypass(self):
+        pred = is_alarm_weapon(
+            to_bytes(this_dir + "/tests_images/test_ocr/bad_quality.JPG"),
+            quality_check=False,
+        )
+        assert pred == "Not an alarm weapon"
+
+    def test_NotAlarm(self):
+        pred = is_alarm_weapon(
+            to_bytes(this_dir + "/tests_images/test_ocr/not_alarm.JPG")
+        )
+        assert pred == "Not an alarm weapon"
+
+    def test_PAK(self):
+        pred = is_alarm_weapon(to_bytes(this_dir + "/tests_images/test_ocr/PAK.JPG"))
+        assert pred == "alarm weapon PAK"
+
+    def test_AlarmModel(self):
+        pred = is_alarm_weapon(
+            to_bytes(this_dir + "/tests_images/test_ocr/alarm_model.JPG")
+        )
+        assert pred == "alarm weapon from model"
diff --git a/MLpackages/tests/test_measure.py b/MLpackages/tests/test_measure.py
index cf2167b..810baff 100644
--- a/MLpackages/tests/test_measure.py
+++ b/MLpackages/tests/test_measure.py
@@ -1,5 +1,7 @@
 from basegun_ml.measure import get_lengths
+from basegun_ml.exceptions import MissingCard, MissingGun
 import os
+import pytest
 
 
 this_dir, this_filename = os.path.split(__file__)
@@ -12,22 +14,23 @@ def to_bytes(img):
 
 
 def equalMarg(a, b, margin):
-    #measure if the predicted length is close enough to the true length
+    # measure if the predicted length is close enough to the true length
     return abs(b - a) < margin
 
 
 class TestMeasure:
     def test_Noweapon(self):
-        pred = get_lengths(
-            to_bytes(this_dir + "/tests_images/test_measure/noWeapon.JPG"), draw=False
-        )
-        assert pred == (0, 0, 0)
+        with pytest.raises(MissingGun):
+            get_lengths(
+                to_bytes(this_dir + "/tests_images/test_measure/noWeapon.JPG"),
+                draw=False,
+            )
 
     def test_NoCard(self):
-        pred = get_lengths(
-            to_bytes(this_dir + "/tests_images/test_measure/noCard.jpg"), draw=False
-        )
-        assert pred == (0, 0, 0)
+        with pytest.raises(MissingCard):
+            get_lengths(
+                to_bytes(this_dir + "/tests_images/test_measure/noCard.jpg"), draw=False
+            )
 
     def test_perfverrou(self):
         pred = get_lengths(
diff --git a/MLpackages/tests/tests_images/test_ocr/PAK.JPG b/MLpackages/tests/tests_images/test_ocr/PAK.JPG
new file mode 100644
index 0000000..7577934
Binary files /dev/null and b/MLpackages/tests/tests_images/test_ocr/PAK.JPG differ
diff --git a/MLpackages/tests/tests_images/test_ocr/alarm_model.JPG b/MLpackages/tests/tests_images/test_ocr/alarm_model.JPG
new file mode 100644
index 0000000..143e1fd
Binary files /dev/null and b/MLpackages/tests/tests_images/test_ocr/alarm_model.JPG differ
diff --git a/MLpackages/tests/tests_images/test_ocr/bad_quality.JPG b/MLpackages/tests/tests_images/test_ocr/bad_quality.JPG
new file mode 100644
index 0000000..303c6ac
Binary files /dev/null and b/MLpackages/tests/tests_images/test_ocr/bad_quality.JPG differ
diff --git a/MLpackages/tests/tests_images/test_ocr/no_text.JPG b/MLpackages/tests/tests_images/test_ocr/no_text.JPG
new file mode 100644
index 0000000..b6f9b53
Binary files /dev/null and b/MLpackages/tests/tests_images/test_ocr/no_text.JPG differ
diff --git a/MLpackages/tests/tests_images/test_ocr/not_alarm.JPG b/MLpackages/tests/tests_images/test_ocr/not_alarm.JPG
new file mode 100644
index 0000000..87c41f7
Binary files /dev/null and b/MLpackages/tests/tests_images/test_ocr/not_alarm.JPG differ