Skip to content

Commit

Permalink
fix: Raise 500 exception instead of returning an error on detect AP…
Browse files Browse the repository at this point in the history
…I call (#863)

fix: Raise 422 exception instead of returning an error on `detect` API call
  • Loading branch information
ramyma authored May 4, 2023
1 parent 23c0c80 commit 5d387ab
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
9 changes: 7 additions & 2 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import copy
import pydantic
Expand Down Expand Up @@ -205,9 +206,13 @@ async def detect(
controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)

if controlnet_module not in global_state.cn_preprocessor_modules:
return {"images": [], "info": "Module not available"}
raise HTTPException(
status_code=422, detail="Module not available")


if len(controlnet_input_images) == 0:
return {"images": [], "info": "No image selected"}
raise HTTPException(
status_code=422, detail="No image selected")

print(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")

Expand Down
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_flag = 'v1.1.133'
version_flag = 'v1.1.134'
print(f'ControlNet {version_flag}')
# A smart trick to know if user has updated as well as if user has restarted terminal.
# Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
Expand Down
12 changes: 9 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os, sys, cv2
import os
import sys
import cv2
from base64 import b64encode

import requests

BASE_URL = "http://localhost:7860"


def setup_test_env():
ext_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if ext_root not in sys.path:
Expand All @@ -19,7 +22,7 @@ def readImage(path):


def get_model():
r = requests.get("http://localhost:7860/controlnet/model_list")
r = requests.get(BASE_URL+"/controlnet/model_list")
result = r.json()
if "model_list" in result:
result = result["model_list"]
Expand All @@ -31,4 +34,7 @@ def get_model():

def get_modules():
return requests.get(f"{BASE_URL}/controlnet/module_list").json().get('module_list', [])



def detect(json):
return requests.post(BASE_URL+"/controlnet/detect", json=json)
41 changes: 41 additions & 0 deletions tests/web_api/detect_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import requests
import unittest
import importlib
utils = importlib.import_module(
'extensions.sd-webui-controlnet.tests.utils', 'utils')
utils.setup_test_env()


class TestDetectEndpointWorking(unittest.TestCase):
def setUp(self):
self.base_detect_args = {
"controlnet_module": "canny",
"controlnet_input_images": [utils.readImage("test/test_files/img2img_basic.png")],
"controlnet_processor_res": 512,
"controlnet_threshold_a": 0,
"controlnet_threshold_b": 0,
}

def test_detect_with_invalid_module_performed(self):
detect_args = self.base_detect_args.copy()
detect_args.update({
"controlnet_module": "INVALID",
})
self.assertEqual(utils.detect(detect_args).status_code, 422)

def test_detect_with_no_input_images_performed(self):
detect_args = self.base_detect_args.copy()
detect_args.update({
"controlnet_input_images": [],
})
self.assertEqual(utils.detect(detect_args).status_code, 422)

def test_detect_with_valid_args_performed(self):
detect_args = self.base_detect_args
response = utils.detect(detect_args)

self.assertEqual(response.status_code, 200)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5d387ab

Please sign in to comment.