diff --git a/care/facility/api/serializers/prescription.py b/care/facility/api/serializers/prescription.py index 71cd5c354..db24d83d1 100644 --- a/care/facility/api/serializers/prescription.py +++ b/care/facility/api/serializers/prescription.py @@ -75,6 +75,18 @@ class Meta: class PrescriptionSerializer(serializers.ModelSerializer): + ERROR_MESSAGES = { + "prn_indicator": "Indicator should be set for PRN prescriptions.", + "titrated_target": "Target dosage should be set for titrated prescriptions.", + "frequency": "Frequency should be set for prescriptions.", + "invalid_dosage_format": "Invalid dosage format. Expected 'number unit' but got '{}'", + "negative_dosage": "Dosage value cannot be negative: {}", + } + # Class-level constants for field names + PRN_FIELDS = {"indicator"} + TITRATED_FIELDS = {"target_dosage"} + STANDARD_FIELDS = {"frequency", "days"} + id = serializers.UUIDField(source="external_id", read_only=True) prescribed_by = UserBaseMinimumSerializer(read_only=True) last_administration = MedicineAdministrationSerializer(read_only=True) @@ -97,63 +109,135 @@ class Meta: "is_migrated", ) - def validate(self, attrs): + def _remove_irrelevant_fields(self, attrs: dict, keep_fields: set[str]) -> None: + """Remove fields not relevant for the current dosage type""" + all_fields = self.PRN_FIELDS | self.TITRATED_FIELDS | self.STANDARD_FIELDS + for field in all_fields - keep_fields: + attrs.pop(field, None) + + def validate_medicine(self, attrs): + """Validate the medicine field and check for duplicate prescriptions.""" if "medicine" in attrs: + medicine_id = attrs["medicine"] attrs["medicine"] = get_object_or_404( - MedibaseMedicine, external_id=attrs["medicine"] + MedibaseMedicine, external_id=medicine_id ) - if ( - not self.instance - and Prescription.objects.filter( - consultation__external_id=self.context["request"].parser_context[ - "kwargs" - ]["consultation_external_id"], - medicine=attrs["medicine"], - discontinued=False, - ).exists() - ): + # Check for existing prescription + if ( + not self.instance + and Prescription.objects.filter( + consultation__external_id=self.context["request"].parser_context[ + "kwargs" + ]["consultation_external_id"], + medicine=attrs["medicine"], + discontinued=False, + ).exists() + ): + raise serializers.ValidationError( + { + "medicine": ( + "This medicine is already prescribed to this patient. " + "Discontinue the existing prescription to prescribe again." + ) + } + ) + + def _validate_max_dosage_presence(self, base_dosage: str, max_dosage: str) -> None: + if max_dosage and not base_dosage: + raise serializers.ValidationError( + {"max_dosage": "Max dosage cannot be set without base dosage"} + ) + + def _validate_dosage_units(self, base_unit: str, max_unit: str) -> None: + if base_unit != max_unit: raise serializers.ValidationError( { - "medicine": ( - "This medicine is already prescribed to this patient. " - "Please discontinue the existing prescription to prescribe again." - ) + "max_dosage": f"Max dosage units ({max_unit}) must match base dosage units ({base_unit})." } ) - if not attrs.get("base_dosage"): + def validate_dosage(self, attrs): + """Validate base and max dosage.""" + base_dosage = attrs.get("base_dosage") + max_dosage = attrs.get("max_dosage") + + if not base_dosage: raise serializers.ValidationError( - {"base_dosage": "Base dosage is required."} + {"base_dosage": "Base dosage is required"} ) - if attrs.get("dosage_type") == PrescriptionDosageType.PRN: + self._validate_max_dosage_presence(base_dosage, max_dosage) + + if max_dosage: + try: + base_dosage_value, base_unit = self.parse_dosage(base_dosage) + max_dosage_value, max_unit = self.parse_dosage(max_dosage) + + self._validate_dosage_units(base_unit, max_unit) + + if max_dosage_value < base_dosage_value: + raise serializers.ValidationError( + { + "max_dosage": "Max dosage in 24 hours should be greater than or equal to base dosage." + } + ) + except ValueError as e: + raise serializers.ValidationError( + { + "max_dosage": "Invalid dosage format. Expected format: 'number unit' (e.g., '500 mg')" + } + ) from e + + def validate_dosage_type_specific(self, attrs): + """Validate fields specific to dosage types.""" + dosage_type = attrs.get("dosage_type") + + if dosage_type == PrescriptionDosageType.PRN: if not attrs.get("indicator"): raise serializers.ValidationError( - {"indicator": "Indicator should be set for PRN prescriptions."} + {"indicator": self.ERROR_MESSAGES["prn_indicator"]} ) - attrs.pop("frequency", None) - attrs.pop("days", None) + # Remove irrelevant fields + self._remove_irrelevant_fields(attrs, self.PRN_FIELDS) + + elif dosage_type == PrescriptionDosageType.TITRATED: + if not attrs.get("target_dosage"): + raise serializers.ValidationError( + { + "target_dosage": "Target dosage should be set for titrated prescriptions." + } + ) + # Remove irrelevant fields + self._remove_irrelevant_fields(attrs, self.TITRATED_FIELDS) + else: if not attrs.get("frequency"): raise serializers.ValidationError( {"frequency": "Frequency should be set for prescriptions."} ) + # Remove irrelevant fields + self._remove_irrelevant_fields(attrs, self.STANDARD_FIELDS) + + # If it's not PRN or TITRATED, ensure standard fields are respected attrs.pop("indicator", None) attrs.pop("max_dosage", None) attrs.pop("min_hours_between_doses", None) - - if attrs.get("dosage_type") == PrescriptionDosageType.TITRATED: - if not attrs.get("target_dosage"): - raise serializers.ValidationError( - { - "target_dosage": "Target dosage should be set for titrated prescriptions." - } - ) - else: - attrs.pop("target_dosage", None) - - return super().validate(attrs) + attrs.pop("target_dosage", None) + + DOSAGE_PARTS_REQUIRED = 2 # Define a constant for the required parts in dosage + + def parse_dosage(self, dosage): + """Parse the dosage into value and unit parts.""" + parts = dosage.split(" ", maxsplit=1) + if len(parts) != self.DOSAGE_PARTS_REQUIRED: + error_message = self.ERROR_MESSAGES["invalid_dosage_format"].format(dosage) + raise ValueError(error_message) + value = float(parts[0]) + if value < 0: + error_message = self.ERROR_MESSAGES["negative_dosage"].format(value) + raise ValueError(error_message) + return value, parts[1] def create(self, validated_data): if validated_data["consultation"].discharge_date: diff --git a/care/facility/tests/test_prescriptions_api.py b/care/facility/tests/test_prescriptions_api.py index 3168aeb15..a235d9aca 100644 --- a/care/facility/tests/test_prescriptions_api.py +++ b/care/facility/tests/test_prescriptions_api.py @@ -285,3 +285,32 @@ def test_medicine_filter_for_prescription(self): self.assertEqual( prescription["medicine_object"]["name"], self.medicine.name ) + + def test_max_dosage_greater_than_base_dosage(self): + data = self.prescription_data(base_dosage="500 mg", max_dosage="1000 mg") + response = self.client.post( + f"/api/v1/consultation/{self.consultation.external_id}/prescriptions/", + data, + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_max_dosage_equal_to_base_dosage(self): + data = self.prescription_data(base_dosage="500 mg", max_dosage="500 mg") + response = self.client.post( + f"/api/v1/consultation/{self.consultation.external_id}/prescriptions/", + data, + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_max_dosage_less_than_base_dosage(self): + data = self.prescription_data(base_dosage="500 mg", max_dosage="400 mg") + response = self.client.post( + f"/api/v1/consultation/{self.consultation.external_id}/prescriptions/", + data, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("max_dosage", response.data) + self.assertEqual( + response.data["max_dosage"][0], + "Max dosage in 24 hours should be greater than or equal to base dosage.", + )