[AI] Develop the final deployment version and update the README #6

Merged: 3 commits, merged on Nov 22, 2023

Binary file modified .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions .gitignore
@@ -193,3 +193,6 @@ poetry.toml
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python
yolov8l-pose.pt
.gitignore
.DS_Store
12 changes: 9 additions & 3 deletions README.md
@@ -2,7 +2,7 @@

This repository is the artificial intelligence repository for the ReHab project. Artificial intelligence is a crucial element of the project: it provides the guidance videos that show users how to perform their exercises, and it evaluates how well a user is doing by extracting features from the videos the user provides and measuring cosine similarity.
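As a rough illustration of that scoring step, here is a minimal sketch; the function names and the per-frame feature arrays are illustrative assumptions, not the repository's actual API:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two flattened keypoint feature vectors."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom > 0 else 0.0

def score_exercise(guide_features: list, user_features: list) -> float:
    """Average frame-wise similarity between the guide video and the user's video."""
    n = min(len(guide_features), len(user_features))
    sims = [cosine_similarity(guide_features[i], user_features[i]) for i in range(n)]
    return float(np.mean(sims)) if sims else 0.0
```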

We utilize pre-trained models for our system. The baseline model employs Posenet, and this choice might change based on considerations such as the trade-off between communication overhead and computation overhead.
We utilize pre-trained models for our system. The baseline model employs YOLOv8, and this choice might change based on considerations such as the trade-off between communication overhead and computation overhead.

Furthermore, we have implemented the patient and doctor counseling feature through WebRTC. You can check out this functionality in the [Backend](https://github.com/sync-without-async/Rehab-BackEnd) and [Frontend](https://github.com/sync-without-async/Rehab-FrontEnd). We've also added an AI capability that summarizes the counseling content. The original repository can be found at [Rehab-Audio](https://github.com/sync-without-async/Rehab-Audio). The feature development is complete, and we have migrated it to this repository.

@@ -75,6 +75,12 @@ The `--reload` parameter automatically restarts the server whenever there's a ch

## Our Model

The `torchvision.models` module in PyTorch provides various pre-trained and state-of-the-art model architectures. Since extracting human poses from images is crucial, we have used the Keypoint RCNN ResNet50 FPN-based model, which is capable of extracting keypoints. It is easiest to understand as Keypoint R-CNN combined with a ResNet50 FPN backbone.
We had been using models from `torchvision.models`, specifically the KeyPoint Mask R-CNN ResNet101 Backbone model, which allows keypoint extraction. However, despite the model's good accuracy, inference takes a long time and it is hard to bring latency down for the user. In addition, we determined that deploying and running it on the Nvidia Jetson Nano is difficult because of its high FLOPs (or MACs).
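For reference, a minimal sketch of how the previous torchvision-based extractor could be loaded and run, following the `keypointrcnn_resnet50_fpn` API from the documentation linked below; the integration details are assumptions, not the exact code in this repository:

```python
import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn

# COCO-pretrained Keypoint R-CNN with a ResNet50 FPN backbone, in inference mode.
model = keypointrcnn_resnet50_fpn(weights="DEFAULT").eval()

# Dummy RGB frame scaled to [0, 1]; a real frame would come from the uploaded video.
frame = torch.rand(3, 512, 256)

with torch.no_grad():
    output = model([frame])[0]

# "keypoints" has shape [num_people, 17, 3]: (x, y, visibility) per COCO keypoint.
keypoints = output["keypoints"]
scores = output["scores"]
```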

As mentioned in the [official documentation](https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.keypointrcnn_resnet50_fpn.html#torchvision.models.detection.keypointrcnn_resnet50_fpn), the default weights come from a model trained on the COCO Dataset v1. This model has more parameters than the legacy models, but its GFLOPs are lower, resulting in better performance.
To address these challenges and improve performance, we decided to deploy a suitable YOLOv8 model on the edge device.

[YOLOv8](https://docs.ultralytics.com/) is one of the state-of-the-art (SOTA) models, showing the lowest latency among existing YOLO versions while delivering the highest performance. We judged it to be faster and more accurate than the previously used KeyPoint Mask R-CNN ResNet101 Backbone model. It is trained on the same COCO v1 dataset as before, which contains 17 keypoints.
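A minimal sketch of keypoint extraction with the Ultralytics API is shown below; the `yolov8l-pose.pt` checkpoint name comes from the `.gitignore` entry above, and the surrounding details are assumptions rather than the exact code in `models.py`:

```python
from ultralytics import YOLO
import cv2

# Pose-estimation checkpoint; the .gitignore above references yolov8l-pose.pt.
model = YOLO("yolov8l-pose.pt")

frame = cv2.imread("frame.jpg")          # one frame taken from the exercise video
results = model.predict(frame, verbose=False)

# Each detected person carries 17 COCO keypoints as (x, y) pixel coordinates.
keypoints_xy = results[0].keypoints.xy   # tensor of shape [num_people, 17, 2]
boxes_xyxy = results[0].boxes.xyxy       # bounding boxes, as used for cropping in models.py
```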

## Similarity

For similarity measurement, we utilized the example project "just_dance" within the MMPose library. This project is a kind of game in which real-time similarity is assessed between a pre-uploaded dance video and the user's dance video captured through a webcam. Users receive scores based on the similarity, creating a dance scoring game. While this project is strikingly similar in nature to ours, it differs in that it is developed as a Jupyter Notebook, which allows it to be tested locally. You can explore the project [here](https://github.com/open-mmlab/mmpose/tree/main/projects/just_dance).
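In the same spirit as just_dance, frame-wise scoring needs poses that are comparable across people and camera setups. The sketch below normalizes each pose before computing cosine similarity; the normalization choices are our own assumptions, not the MMPose project's exact code:

```python
import numpy as np

def normalize_pose(keypoints: np.ndarray) -> np.ndarray:
    """Make a [17, 2] keypoint array translation- and scale-invariant, then flatten it."""
    centered = keypoints - keypoints.mean(axis=0, keepdims=True)
    scale = np.linalg.norm(centered)
    return (centered / scale).flatten() if scale > 0 else centered.flatten()

def frame_score(guide_kpts: np.ndarray, user_kpts: np.ndarray) -> float:
    """Cosine similarity between two normalized poses; 1.0 means an identical pose."""
    g, u = normalize_pose(guide_kpts), normalize_pose(user_kpts)
    denom = np.linalg.norm(g) * np.linalg.norm(u)
    return float(np.dot(g, u) / denom) if denom > 0 else 0.0
```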
30 changes: 24 additions & 6 deletions app/main.py
@@ -7,6 +7,8 @@
from fastapi.responses import PlainTextResponse
from fastapi.exceptions import RequestValidationError

from pydub import AudioSegment

import pandas as pd
import torch

@@ -23,7 +25,7 @@
EXTRACTOR_THRESHOLD = 0.85

app = FastAPI()
extractor = SkeletonExtractor(pretrained_bool=True, number_of_keypoints=17, device='mps')
extractor = SkeletonExtractor(pretrained_bool=True, number_of_keypoints=17, device='cuda')
preprocessor = DataPreprocessing()
metrics = Metrics()
mmpose_similarity = MMPoseStyleSimilarty()
@@ -48,6 +50,10 @@ async def validation_exception_handler(request, exc):
async def generic_exception_handler(request, exc):
return PlainTextResponse(str(exc), status_code=500)

@app.get("/")
async def root():
return {"message": "rehab_ai_server_api_success"}

@app.post("/videoRegister")
async def registerVideo(
video_file: UploadFile = File(...)
@@ -174,7 +180,7 @@ async def getMetricsConsumer(
return {"metrics": score}

@app.get("/getSummary")
async def getSummary(ano: int = Form(),
async def getSummary(ano: int,
background_tasks: BackgroundTasks = BackgroundTasks()
):
connector, cursor = database_connector(database_secret_path="secret_key.json")
@@ -198,8 +204,9 @@ async def getSummary(ano: int = Form(),
doctor_audio = requests.get(doctor_audio_url).content
patient_audio = requests.get(patient_audio_url).content

with open("doctor.wav", "wb") as f: f.write(doctor_audio)
with open("patient.wav", "wb") as f: f.write(patient_audio)
with open("doctor.wav", "+wb") as f: f.write(doctor_audio)
with open("patient.wav", "+wb") as f: f.write(patient_audio)

doctor_audio, doc_fs = den.load_audio("doctor.wav")
patient_audio, pat_fs = den.load_audio("patient.wav")

@@ -209,12 +216,13 @@ async def getSummary(ano: int = Form(),
doctor_audio=doctor_audio,
patient_audio=patient_audio,
doc_fs=doc_fs,
pat_fs=pat_fs
pat_fs=pat_fs,
)

return True

except Exception as e:
logging.error("[SUMMARY_MODULE] Error occured while getting audio from database.")
logging.error(e)
return False

@@ -261,7 +269,7 @@ async def _do_summary(
summarized = summary.summarize(
doctor_content=doctor_transcript,
patient_content=patient_transcript,
max_tokens=1024,
max_tokens=700,
verbose=True,
)

@@ -281,3 +289,13 @@ async def _do_summary(
logging.info("[SUMMARY_MODULE] Summary has been saved in the database.")
logging.info("[SUMMARY_MODULE] Summary: ")
logging.info(summarized)

def _hide_seek(obj):
    # Wrap a file-like object so that only read() is exposed, hiding seek()
    # from downstream consumers that would fail on a non-seekable stream.
    class _wrapper:
        def __init__(self, obj):
            self.obj = obj

        def read(self, n):
            return self.obj.read(n)

    return _wrapper(obj)
8 changes: 6 additions & 2 deletions denoising.py
@@ -16,11 +16,15 @@ def _check_parallel_device_list():
return device_list

def load_audio(
path: str = None
path: str = None,
verbose: bool = False
):
if path is None: raise ValueError(f"path argument is required. Excepted: str, but got {path}")
if verbose: logging.info(f"Loading audio from {path}...")

audio, sample_rate = torchaudio.load(path, format="wav")

audio, sample_rate = torchaudio.load(path)
if verbose: logging.info("Done!")

return audio, sample_rate

3 changes: 1 addition & 2 deletions models.py
@@ -122,14 +122,13 @@ def extract(self, video_tensor: cv2.VideoCapture, score_threshold: float = 0.93,
start_time = time.time()

# Cropping the image (YoloV8)
get_bounding_box = self.yolov8_model.predict(frame_from_video)[0].boxes.xyxy[0]
get_bounding_box = self.yolov8_model.predict(frame_from_video)[0].boxes.xyxy[0].cpu().numpy()
print(get_bounding_box)

left_top = (get_bounding_box[0], get_bounding_box[1])
right_top = (get_bounding_box[2], get_bounding_box[1])
left_bottom = (get_bounding_box[0], get_bounding_box[3])
right_bottom = (get_bounding_box[2], get_bounding_box[3])
cropping_pts = np.array([left_top, right_top, left_bottom, right_bottom])

cropped_image = frame_from_video[int(left_top[1]):int(right_bottom[1]), int(left_top[0]):int(right_bottom[0])]
cropped_image = cv2.resize(cropped_image, (256, 512))
14 changes: 1 addition & 13 deletions speech_to_text.py
@@ -41,7 +41,6 @@ def __feature_extractor(

return transcription

'''
def speech_to_text_whisper(
pretrained_model_name_or_path: str = None,
audio: np.ndarray = None,
@@ -63,18 +62,7 @@ def speech_to_text_whisper(
logging.info(f"batchsize argument is not provided. Using default value: {batchsize}")

if verbose: logging.info("Loading model...")
processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path)
model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path).to(device)

return __feature_extractor(
model=model,
processor=processor,
audio=audio,
audio_sample_rate=audio_sample_rate,
device=device,
verbose=verbose,
)
'''


def speech_to_text(
processor_pretrained_argument: str = None,
4 changes: 2 additions & 2 deletions summary.py
@@ -8,15 +8,15 @@
with open("secret_key.json") as f: secret_key = json.load(f)

openai.api_key = secret_key['OpenAI']['API_KEY']
MODEL_NAME = "gpt-3.5-turbo"
MODEL_NAME = "gpt-4-1106-preview"

def _get_prompt(
doctor_content: str = None,
patient_content: str = None) -> list:
if doctor_content is None: raise ValueError(f"doctor_content argument is required. Excepted: str, but got {doctor_content}")

system_prompt = f"당신은 문서 정리를 하는 서기 입니다. 그 중에서도 두 사람의 대화 내용을 듣고 어떤 대화인지 정리 요약하는 서기 입니다. 두 사람의 대화 내용을 당신에게 전달할 것입니다. 한명은 의사, 한명은 환자입니다. 의사는 DOC:<TEXT HERE>로 드릴 것이며 환자는 PATIENT:<TEXT HERE>로 드릴 예정입니다. 대화 내용에서는 시간적 특성이 배제 되어 있습니다."
assistant_prompt = "당신은 별 다른 시간 인덱스가 없더라도 내용을 파악하고 이해하셔야 합니다. 요약 정리는 Markdown 문서 형식으로 정리가 되어야합니다. #주요대화내용, #의사요점, #환자요점 으로 정리합니다. #주요대화내용은 대화의 주제로 환자가 어디가 아파하는지, 어떤 도움이 필요한지 정리합니다. #의사요점은 의사가 말한 내용 중에서 중심적으로 봐야하거나 중요하게 여긴 내용을 정리합니다. #환자요점은 환자가 말한 증상 혹은 현재 상태에 대해 내용을 정리합니다."
assistant_prompt = "당신은 별 다른 시간 정보가 없더라도 내용을 파악하고 이해하셔야 합니다. 요약 정리는 Markdown 문서 형식으로 정리가 되어야합니다. #주요대화내용, #의사요점, #환자요점 으로 정리합니다. #주요대화내용은 대화의 주제로 환자가 어디가 아파하는지, 어떤 도움이 필요한지 정리합니다. #의사요점은 의사가 말한 내용 중에서 중심적으로 봐야하거나 중요하게 여긴 내용을 정리합니다. #환자요점은 환자가 말한 증상 혹은 현재 상태에 대해 내용을 정리합니다."
user_prompt = f"DOC:<{doctor_content}> PATIENT:<{patient_content}>"

return [