Skip to content

Commit

Permalink
Remove deprecated api (#1154)
Browse files Browse the repository at this point in the history
* remove deprecated api

* version

* api version
  • Loading branch information
ljleb authored May 5, 2023
1 parent 5d387ab commit 11d33e1
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 159 deletions.
159 changes: 2 additions & 157 deletions scripts/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import Union

import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import copy
import pydantic
import sys

import gradio as gr

Expand All @@ -29,155 +24,6 @@ def encode_np_to_base64(image):
pil = Image.fromarray(image)
return api.encode_pil_to_base64(pil)

cn_root_field_prefix = 'controlnet_'
cn_fields = {
"input_image": (str, Field(default="", title='ControlNet Input Image')),
"mask": (str, Field(default="", title='ControlNet Input Mask')),
"module": (str, Field(default="none", title='Controlnet Module')),
"model": (str, Field(default="None", title='Controlnet Model')),
"weight": (float, Field(default=1.0, title='Controlnet Weight')),
"resize_mode": (Union[int, str], Field(default="Crop and Resize", title='Controlnet Resize Mode')),
"lowvram": (bool, Field(default=False, title='Controlnet Low VRAM')),
"processor_res": (int, Field(default=64, title='Controlnet Processor Res')),
"threshold_a": (float, Field(default=64, title='Controlnet Threshold a')),
"threshold_b": (float, Field(default=64, title='Controlnet Threshold b')),
"guidance": (float, Field(default=1.0, title='ControlNet Guidance Strength')),
"guidance_start": (float, Field(0.0, title='ControlNet Guidance Start')),
"guidance_end": (float, Field(1.0, title='ControlNet Guidance End')),
"guessmode": (bool, Field(default=True, title="Guess Mode")),
"pixel_perfect": (bool, Field(default=False, title="Pixel Perfect"))
}

def get_deprecated_cn_field(field_name: str, field):
field_type, field = field
field = copy.copy(field)
field.default = None
field.extra['_deprecated'] = True
if field_name in ('input_image', 'mask'):
field_type = List[field_type]
return f'{cn_root_field_prefix}{field_name}', (field_type, field)

def get_deprecated_field_default(field_name: str):
if field_name in ('input_image', 'mask'):
return []
return cn_fields[field_name][-1].default

ControlNetUnitRequest = pydantic.create_model('ControlNetUnitRequest', **cn_fields)

def create_controlnet_request_model(p_api_class):
class RequestModel(p_api_class):
class Config(p_api_class.__config__):
@staticmethod
def schema_extra(schema: dict, _):
props = {}
for k, v in schema.get('properties', {}).items():
if not v.get('_deprecated', False):
props[k] = v
if v.get('docs_default', None) is not None:
v['default'] = v['docs_default']
if props:
schema['properties'] = props

additional_fields = {
'controlnet_units': (List[ControlNetUnitRequest], Field(default=[], docs_default=[ControlNetUnitRequest()], description="ControlNet Processing Units")),
**dict(get_deprecated_cn_field(k, v) for k, v in cn_fields.items())
}

return pydantic.create_model(
f'ControlNet{p_api_class.__name__}',
__base__=RequestModel,
**additional_fields)

ControlNetTxt2ImgRequest = create_controlnet_request_model(StableDiffusionTxt2ImgProcessingAPI)
ControlNetImg2ImgRequest = create_controlnet_request_model(StableDiffusionImg2ImgProcessingAPI)

class ApiHijack(api.Api):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_api_route("/controlnet/txt2img", self.controlnet_txt2img, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/controlnet/img2img", self.controlnet_img2img, methods=["POST"], response_model=ImageToImageResponse)

def controlnet_txt2img(self, txt2img_request: ControlNetTxt2ImgRequest):
return self.controlnet_any2img(
any2img_request=txt2img_request,
original_callback=ApiHijack.text2imgapi,
is_img2img=False,
)

def controlnet_img2img(self, img2img_request: ControlNetImg2ImgRequest):
return self.controlnet_any2img(
any2img_request=img2img_request,
original_callback=ApiHijack.img2imgapi,
is_img2img=True,
)

def controlnet_any2img(self, any2img_request, original_callback, is_img2img):
warn_deprecated_route(is_img2img)
any2img_request = nest_deprecated_cn_fields(any2img_request)
alwayson_scripts = dict(any2img_request.alwayson_scripts)
any2img_request.alwayson_scripts.update({'ControlNet': {'args': [to_api_cn_unit(unit) for unit in any2img_request.controlnet_units]}})
controlnet_units = any2img_request.controlnet_units
delattr(any2img_request, 'controlnet_units')
result = original_callback(self, any2img_request)
result.parameters['controlnet_units'] = controlnet_units
result.parameters['alwayson_scripts'] = alwayson_scripts
return result

api.Api = ApiHijack

def nest_deprecated_cn_fields(any2img_request):
deprecated_cn_fields = {k: v for k, v in vars(any2img_request).items()
if k.startswith(cn_root_field_prefix) and k != 'controlnet_units'}

any2img_request = copy.copy(any2img_request)
for k in deprecated_cn_fields.keys():
delattr(any2img_request, k)

if all(v is None for v in deprecated_cn_fields.values()):
return any2img_request

deprecated_cn_fields = {k[len(cn_root_field_prefix):]: v for k, v in deprecated_cn_fields.items()}
for k, v in deprecated_cn_fields.items():
if v is None:
deprecated_cn_fields[k] = get_deprecated_field_default(k)

for k in ('input_image', 'mask'):
deprecated_cn_fields[k] = deprecated_cn_fields[k][0] if deprecated_cn_fields[k] else ""

any2img_request.controlnet_units.insert(0, ControlNetUnitRequest(**deprecated_cn_fields))
return any2img_request

def to_api_cn_unit(unit_request: ControlNetUnitRequest) -> external_code.ControlNetUnit:
input_image = external_code.to_base64_nparray(unit_request.input_image) if unit_request.input_image else None
mask = external_code.to_base64_nparray(unit_request.mask) if unit_request.mask else None
if input_image is not None and mask is not None:
input_image = (input_image, mask)

if unit_request.guidance < 1.0:
unit_request.guidance_end = unit_request.guidance

return external_code.ControlNetUnit(
module=unit_request.module,
model=unit_request.model,
weight=unit_request.weight,
image=input_image,
resize_mode=unit_request.resize_mode,
low_vram=unit_request.lowvram,
processor_res=unit_request.processor_res,
threshold_a=unit_request.threshold_a,
threshold_b=unit_request.threshold_b,
guidance_start=unit_request.guidance_start,
guidance_end=unit_request.guidance_end,
guess_mode=unit_request.guessmode,
pixel_perfect=unit_request.pixel_perfect
)

def warn_deprecated_route(is_img2img):
route = 'img2img' if is_img2img else 'txt2img'
warning_prefix = '[ControlNet] warning: '
print(f"{warning_prefix}using deprecated '/controlnet/{route}' route", file=sys.stderr)
print(f"{warning_prefix}consider using the '/sdapi/v1/{route}' route with the 'alwayson_scripts' json property instead", file=sys.stderr)

def controlnet_api(_: gr.Blocks, app: FastAPI):
@app.get("/controlnet/version")
async def version():
Expand Down Expand Up @@ -208,12 +54,11 @@ async def detect(
if controlnet_module not in global_state.cn_preprocessor_modules:
raise HTTPException(
status_code=422, detail="Module not available")



if len(controlnet_input_images) == 0:
raise HTTPException(
status_code=422, detail="No image selected")

print(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")

results = []
Expand Down
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_flag = 'v1.1.134'
version_flag = 'v1.1.135'
print(f'ControlNet {version_flag}')
# A smart trick to know if user has updated as well as if user has restarted terminal.
# Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
Expand Down
2 changes: 1 addition & 1 deletion scripts/external_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_api_version() -> int:
return 1
return 2


class ControlMode(Enum):
Expand Down

0 comments on commit 11d33e1

Please sign in to comment.