Skip to content

Commit

Permalink
Merge pull request #2299 from coronasafe/sainak/fix/atomicity-of-cons…
Browse files Browse the repository at this point in the history
…ultations

Fix race condition in consultation creation
  • Loading branch information
vigneshhari authored Aug 19, 2024
2 parents e95d98d + af9da5b commit 72f9e9a
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 118 deletions.
230 changes: 115 additions & 115 deletions care/facility/api/serializers/patient_consultation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
UserBaseMinimumSerializer,
)
from care.users.models import User
from care.utils.lock import Lock
from care.utils.notification_handler import NotificationGenerator
from care.utils.queryset.facility import get_home_facility_queryset
from care.utils.serializer.external_id_field import ExternalIdSerializerField
Expand Down Expand Up @@ -191,6 +192,9 @@ def get_discharge_prn_prescription(self, consultation):
dosage_type=PrescriptionDosageType.PRN.value,
).values()

def _lock_key(self, patient_id):
return f"patient_consultation__patient_registration__{patient_id}"

class Meta:
model = PatientConsultation
read_only_fields = TIMESTAMP_FIELDS + (
Expand Down Expand Up @@ -353,149 +357,145 @@ def create(self, validated_data):

create_diagnosis = validated_data.pop("create_diagnoses")
create_symptoms = validated_data.pop("create_symptoms")
action = -1
review_interval = -1
if "action" in validated_data:
action = validated_data.pop("action")
if "review_interval" in validated_data:
review_interval = validated_data.pop("review_interval")

action = validated_data.pop("action", -1)
review_interval = validated_data.get("review_interval", -1)

# Authorisation Check

allowed_facilities = get_home_facility_queryset(self.context["request"].user)
user = self.context["request"].user
allowed_facilities = get_home_facility_queryset(user)
if not allowed_facilities.filter(
id=self.validated_data["patient"].facility.id
id=self.validated_data["patient"].facility_id
).exists():
raise ValidationError(
{"facility": "Consultation creates are only allowed in home facility"}
)

# End Authorisation Checks

if validated_data["patient"].last_consultation:
with (
Lock(self._lock_key(validated_data["patient"].id)),
transaction.atomic(),
):
patient = validated_data["patient"]
if patient.last_consultation:
if patient.last_consultation.assigned_to == user:
raise ValidationError(
{
"Permission Denied": "Only Facility Staff can create consultation for a Patient"
},
)

if not patient.last_consultation.discharge_date:
raise ValidationError(
{"consultation": "Exists please Edit Existing Consultation"}
)

if "is_kasp" in validated_data:
if validated_data["is_kasp"]:
validated_data["kasp_enabled_date"] = now()

# Coercing facility as the patient's facility
validated_data["facility_id"] = patient.facility_id

consultation: PatientConsultation = super().create(validated_data)
consultation.created_by = user
consultation.last_edited_by = user
consultation.previous_consultation = patient.last_consultation
last_consultation = patient.last_consultation
if (
self.context["request"].user
== validated_data["patient"].last_consultation.assigned_to
last_consultation
and consultation.suggestion == SuggestionChoices.A
and last_consultation.suggestion == SuggestionChoices.A
and last_consultation.discharge_date
and last_consultation.discharge_date + timedelta(days=30)
> consultation.encounter_date
):
raise ValidationError(
{
"Permission Denied": "Only Facility Staff can create consultation for a Patient"
},
)
consultation.is_readmission = True

diagnosis = ConsultationDiagnosis.objects.bulk_create(
[
ConsultationDiagnosis(
consultation=consultation,
diagnosis_id=obj["diagnosis"].id,
is_principal=obj["is_principal"],
verification_status=obj["verification_status"],
created_by=user,
)
for obj in create_diagnosis
]
)

if validated_data["patient"].last_consultation:
if not validated_data["patient"].last_consultation.discharge_date:
raise ValidationError(
{"consultation": "Exists please Edit Existing Consultation"}
symptoms = EncounterSymptom.objects.bulk_create(
EncounterSymptom(
consultation=consultation,
symptom=obj.get("symptom"),
onset_date=obj.get("onset_date"),
cure_date=obj.get("cure_date"),
clinical_impression_status=obj.get("clinical_impression_status"),
other_symptom=obj.get("other_symptom") or "",
created_by=user,
)
for obj in create_symptoms
)

if "is_kasp" in validated_data:
if validated_data["is_kasp"]:
validated_data["kasp_enabled_date"] = localtime(now())

bed = validated_data.pop("bed", None)

validated_data["facility_id"] = validated_data[
"patient"
].facility_id # Coercing facility as the patient's facility
consultation = super().create(validated_data)
consultation.created_by = self.context["request"].user
consultation.last_edited_by = self.context["request"].user
patient = consultation.patient
consultation.previous_consultation = patient.last_consultation
last_consultation = patient.last_consultation
if (
last_consultation
and consultation.suggestion == SuggestionChoices.A
and last_consultation.suggestion == SuggestionChoices.A
and last_consultation.discharge_date
and last_consultation.discharge_date + timedelta(days=30)
> consultation.encounter_date
):
consultation.is_readmission = True
consultation.save()

diagnosis = ConsultationDiagnosis.objects.bulk_create(
[
ConsultationDiagnosis(
bed = validated_data.pop("bed", None)
if bed and consultation.suggestion == SuggestionChoices.A:
consultation_bed = ConsultationBed(
bed=bed,
consultation=consultation,
diagnosis_id=obj["diagnosis"].id,
is_principal=obj["is_principal"],
verification_status=obj["verification_status"],
created_by=self.context["request"].user,
start_date=consultation.created_date,
)
for obj in create_diagnosis
]
)
consultation_bed.save()
consultation.current_bed = consultation_bed

symptoms = EncounterSymptom.objects.bulk_create(
EncounterSymptom(
consultation=consultation,
symptom=obj.get("symptom"),
onset_date=obj.get("onset_date"),
cure_date=obj.get("cure_date"),
clinical_impression_status=obj.get("clinical_impression_status"),
other_symptom=obj.get("other_symptom") or "",
created_by=self.context["request"].user,
)
for obj in create_symptoms
)
if consultation.suggestion == SuggestionChoices.OP:
consultation.discharge_date = now()
patient.is_active = False
patient.allow_transfer = True
else:
patient.is_active = True
patient.last_consultation = consultation

if bed and consultation.suggestion == SuggestionChoices.A:
consultation_bed = ConsultationBed(
bed=bed,
consultation=consultation,
start_date=consultation.created_date,
)
consultation_bed.save()
consultation.current_bed = consultation_bed
consultation.save(update_fields=["current_bed"])
if action != -1:
patient.action = action

if consultation.suggestion == SuggestionChoices.OP:
consultation.discharge_date = localtime(now())
consultation.save()
patient.is_active = False
patient.allow_transfer = True
else:
patient.is_active = True
patient.last_consultation = consultation

if action != -1:
patient.action = action
consultation.review_interval = review_interval
if review_interval > 0:
patient.review_time = localtime(now()) + timedelta(minutes=review_interval)
else:
patient.review_time = None
if review_interval > 0:
patient.review_time = now() + timedelta(minutes=review_interval)
else:
patient.review_time = None

patient.save()
NotificationGenerator(
event=Notification.Event.PATIENT_CONSULTATION_CREATED,
caused_by=self.context["request"].user,
caused_object=consultation,
facility=patient.facility,
).generate()
consultation.save()
patient.save()

create_consultation_events(
consultation.id,
(consultation, *diagnosis, *symptoms),
consultation.created_by.id,
consultation.created_date,
)
create_consultation_events(
consultation.id,
(consultation, *diagnosis, *symptoms),
consultation.created_by.id,
consultation.created_date,
)

if consultation.assigned_to:
NotificationGenerator(
event=Notification.Event.PATIENT_CONSULTATION_ASSIGNMENT,
caused_by=self.context["request"].user,
event=Notification.Event.PATIENT_CONSULTATION_CREATED,
caused_by=user,
caused_object=consultation,
facility=consultation.patient.facility,
notification_mediums=[
Notification.Medium.SYSTEM,
Notification.Medium.WHATSAPP,
],
facility=patient.facility,
).generate()

return consultation
if consultation.assigned_to:
NotificationGenerator(
event=Notification.Event.PATIENT_CONSULTATION_ASSIGNMENT,
caused_by=user,
caused_object=consultation,
facility=consultation.patient.facility,
notification_mediums=[
Notification.Medium.SYSTEM,
Notification.Medium.WHATSAPP,
],
).generate()

return consultation

def validate_create_diagnoses(self, value):
# Reject if create_diagnoses is present for edits
Expand Down
5 changes: 5 additions & 0 deletions care/facility/api/viewsets/patient_consultation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.db import transaction
from django.db.models import Prefetch
from django.db.models.query_utils import Q
from django.shortcuts import get_object_or_404, render
Expand Down Expand Up @@ -109,6 +110,10 @@ def get_queryset(self):
applied_filters |= Q(facility=self.request.user.home_facility)
return self.queryset.filter(applied_filters)

@transaction.non_atomic_requests
def create(self, request, *args, **kwargs) -> Response:
return super().create(request, *args, **kwargs)

@extend_schema(tags=["consultation"])
@action(detail=True, methods=["POST"])
def discharge_patient(self, request, *args, **kwargs):
Expand Down
1 change: 0 additions & 1 deletion care/users/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ def test_last_active_filter(self):
response = self.client.get("/api/v1/users/?last_active_days=10")
self.assertEqual(response.status_code, status.HTTP_200_OK)
res_data_json = response.json()
print(res_data_json)
self.assertEqual(res_data_json["count"], 2)
self.assertIn(
self.user_2.username, {r["username"] for r in res_data_json["results"]}
Expand Down
30 changes: 30 additions & 0 deletions care/utils/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django.conf import settings
from django.core.cache import cache
from rest_framework.exceptions import APIException


class ObjectLocked(APIException):
status_code = 423
default_detail = "The resource you are trying to access is locked"
default_code = "object_locked"


class Lock:
def __init__(self, key, timeout=settings.LOCK_TIMEOUT):
self.key = f"lock:{key}"
self.timeout = timeout

def acquire(self):
if not cache.set(self.key, True, self.timeout, nx=True):
raise ObjectLocked()

def release(self):
return cache.delete(self.key)

def __enter__(self):
self.acquire()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.release()
return False
2 changes: 1 addition & 1 deletion care/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, decorated):
super().__init__(
CACHES={
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
"BACKEND": "config.caches.LocMemCache",
"LOCATION": f"care-test-{uuid.uuid4()}",
}
},
Expand Down
16 changes: 16 additions & 0 deletions config/caches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.core.cache.backends import dummy, locmem
from django.core.cache.backends.base import DEFAULT_TIMEOUT


class DummyCache(dummy.DummyCache):
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=None):
super().set(key, value, timeout, version)
# mimic the behavior of django_redis with setnx, for tests
return True


class LocMemCache(locmem.LocMemCache):
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=None):
super().set(key, value, timeout, version)
# mimic the behavior of django_redis with setnx, for tests
return True
3 changes: 3 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@
DATABASES["default"]["CONN_MAX_AGE"] = env.int("CONN_MAX_AGE", default=0)
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"

# timeout for setnx lock
LOCK_TIMEOUT = env.int("LOCK_TIMEOUT", default=32)

REDIS_URL = env("REDIS_URL", default="redis://localhost:6379")

# CACHES
Expand Down
2 changes: 1 addition & 1 deletion config/settings/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# test in peace
CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.dummy.DummyCache",
"BACKEND": "config.caches.DummyCache",
}
}
# for testing retelimit use override_settings decorator
Expand Down

0 comments on commit 72f9e9a

Please sign in to comment.