diff --git a/scripts/api.py b/scripts/api.py index d3b2f65b7..57c29d45a 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -1,12 +1,7 @@ -from typing import Union - import numpy as np from fastapi import FastAPI, Body from fastapi.exceptions import HTTPException from PIL import Image -import copy -import pydantic -import sys import gradio as gr @@ -29,155 +24,6 @@ def encode_np_to_base64(image): pil = Image.fromarray(image) return api.encode_pil_to_base64(pil) -cn_root_field_prefix = 'controlnet_' -cn_fields = { - "input_image": (str, Field(default="", title='ControlNet Input Image')), - "mask": (str, Field(default="", title='ControlNet Input Mask')), - "module": (str, Field(default="none", title='Controlnet Module')), - "model": (str, Field(default="None", title='Controlnet Model')), - "weight": (float, Field(default=1.0, title='Controlnet Weight')), - "resize_mode": (Union[int, str], Field(default="Crop and Resize", title='Controlnet Resize Mode')), - "lowvram": (bool, Field(default=False, title='Controlnet Low VRAM')), - "processor_res": (int, Field(default=64, title='Controlnet Processor Res')), - "threshold_a": (float, Field(default=64, title='Controlnet Threshold a')), - "threshold_b": (float, Field(default=64, title='Controlnet Threshold b')), - "guidance": (float, Field(default=1.0, title='ControlNet Guidance Strength')), - "guidance_start": (float, Field(0.0, title='ControlNet Guidance Start')), - "guidance_end": (float, Field(1.0, title='ControlNet Guidance End')), - "guessmode": (bool, Field(default=True, title="Guess Mode")), - "pixel_perfect": (bool, Field(default=False, title="Pixel Perfect")) -} - -def get_deprecated_cn_field(field_name: str, field): - field_type, field = field - field = copy.copy(field) - field.default = None - field.extra['_deprecated'] = True - if field_name in ('input_image', 'mask'): - field_type = List[field_type] - return f'{cn_root_field_prefix}{field_name}', (field_type, field) - -def get_deprecated_field_default(field_name: str): - if field_name in ('input_image', 'mask'): - return [] - return cn_fields[field_name][-1].default - -ControlNetUnitRequest = pydantic.create_model('ControlNetUnitRequest', **cn_fields) - -def create_controlnet_request_model(p_api_class): - class RequestModel(p_api_class): - class Config(p_api_class.__config__): - @staticmethod - def schema_extra(schema: dict, _): - props = {} - for k, v in schema.get('properties', {}).items(): - if not v.get('_deprecated', False): - props[k] = v - if v.get('docs_default', None) is not None: - v['default'] = v['docs_default'] - if props: - schema['properties'] = props - - additional_fields = { - 'controlnet_units': (List[ControlNetUnitRequest], Field(default=[], docs_default=[ControlNetUnitRequest()], description="ControlNet Processing Units")), - **dict(get_deprecated_cn_field(k, v) for k, v in cn_fields.items()) - } - - return pydantic.create_model( - f'ControlNet{p_api_class.__name__}', - __base__=RequestModel, - **additional_fields) - -ControlNetTxt2ImgRequest = create_controlnet_request_model(StableDiffusionTxt2ImgProcessingAPI) -ControlNetImg2ImgRequest = create_controlnet_request_model(StableDiffusionImg2ImgProcessingAPI) - -class ApiHijack(api.Api): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.add_api_route("/controlnet/txt2img", self.controlnet_txt2img, methods=["POST"], response_model=TextToImageResponse) - self.add_api_route("/controlnet/img2img", self.controlnet_img2img, methods=["POST"], response_model=ImageToImageResponse) - - def controlnet_txt2img(self, txt2img_request: ControlNetTxt2ImgRequest): - return self.controlnet_any2img( - any2img_request=txt2img_request, - original_callback=ApiHijack.text2imgapi, - is_img2img=False, - ) - - def controlnet_img2img(self, img2img_request: ControlNetImg2ImgRequest): - return self.controlnet_any2img( - any2img_request=img2img_request, - original_callback=ApiHijack.img2imgapi, - is_img2img=True, - ) - - def controlnet_any2img(self, any2img_request, original_callback, is_img2img): - warn_deprecated_route(is_img2img) - any2img_request = nest_deprecated_cn_fields(any2img_request) - alwayson_scripts = dict(any2img_request.alwayson_scripts) - any2img_request.alwayson_scripts.update({'ControlNet': {'args': [to_api_cn_unit(unit) for unit in any2img_request.controlnet_units]}}) - controlnet_units = any2img_request.controlnet_units - delattr(any2img_request, 'controlnet_units') - result = original_callback(self, any2img_request) - result.parameters['controlnet_units'] = controlnet_units - result.parameters['alwayson_scripts'] = alwayson_scripts - return result - -api.Api = ApiHijack - -def nest_deprecated_cn_fields(any2img_request): - deprecated_cn_fields = {k: v for k, v in vars(any2img_request).items() - if k.startswith(cn_root_field_prefix) and k != 'controlnet_units'} - - any2img_request = copy.copy(any2img_request) - for k in deprecated_cn_fields.keys(): - delattr(any2img_request, k) - - if all(v is None for v in deprecated_cn_fields.values()): - return any2img_request - - deprecated_cn_fields = {k[len(cn_root_field_prefix):]: v for k, v in deprecated_cn_fields.items()} - for k, v in deprecated_cn_fields.items(): - if v is None: - deprecated_cn_fields[k] = get_deprecated_field_default(k) - - for k in ('input_image', 'mask'): - deprecated_cn_fields[k] = deprecated_cn_fields[k][0] if deprecated_cn_fields[k] else "" - - any2img_request.controlnet_units.insert(0, ControlNetUnitRequest(**deprecated_cn_fields)) - return any2img_request - -def to_api_cn_unit(unit_request: ControlNetUnitRequest) -> external_code.ControlNetUnit: - input_image = external_code.to_base64_nparray(unit_request.input_image) if unit_request.input_image else None - mask = external_code.to_base64_nparray(unit_request.mask) if unit_request.mask else None - if input_image is not None and mask is not None: - input_image = (input_image, mask) - - if unit_request.guidance < 1.0: - unit_request.guidance_end = unit_request.guidance - - return external_code.ControlNetUnit( - module=unit_request.module, - model=unit_request.model, - weight=unit_request.weight, - image=input_image, - resize_mode=unit_request.resize_mode, - low_vram=unit_request.lowvram, - processor_res=unit_request.processor_res, - threshold_a=unit_request.threshold_a, - threshold_b=unit_request.threshold_b, - guidance_start=unit_request.guidance_start, - guidance_end=unit_request.guidance_end, - guess_mode=unit_request.guessmode, - pixel_perfect=unit_request.pixel_perfect - ) - -def warn_deprecated_route(is_img2img): - route = 'img2img' if is_img2img else 'txt2img' - warning_prefix = '[ControlNet] warning: ' - print(f"{warning_prefix}using deprecated '/controlnet/{route}' route", file=sys.stderr) - print(f"{warning_prefix}consider using the '/sdapi/v1/{route}' route with the 'alwayson_scripts' json property instead", file=sys.stderr) - def controlnet_api(_: gr.Blocks, app: FastAPI): @app.get("/controlnet/version") async def version(): @@ -208,12 +54,11 @@ async def detect( if controlnet_module not in global_state.cn_preprocessor_modules: raise HTTPException( status_code=422, detail="Module not available") - - + if len(controlnet_input_images) == 0: raise HTTPException( status_code=422, detail="No image selected") - + print(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.") results = [] diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py index 2e9db5ee0..d594f15a3 100644 --- a/scripts/controlnet_version.py +++ b/scripts/controlnet_version.py @@ -1,4 +1,4 @@ -version_flag = 'v1.1.134' +version_flag = 'v1.1.135' print(f'ControlNet {version_flag}') # A smart trick to know if user has updated as well as if user has restarted terminal. # Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py" diff --git a/scripts/external_code.py b/scripts/external_code.py index 65e413999..6c71f17e3 100644 --- a/scripts/external_code.py +++ b/scripts/external_code.py @@ -9,7 +9,7 @@ def get_api_version() -> int: - return 1 + return 2 class ControlMode(Enum):