Skip to content

Commit

Permalink
Allow steps to be < 10 (for SDXL Turbo)
Browse files Browse the repository at this point in the history
  • Loading branch information
benrugg committed Dec 1, 2023
1 parent 878e955 commit 7387975
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 60 deletions.
3 changes: 2 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "AI Render - Stable Diffusion in Blender",
"description": "Create amazing images using Stable Diffusion AI",
"author": "Ben Rugg",
"version": (1, 0, 0),
"version": (1, 0, 1),
"blender": (3, 0, 0),
"location": "Render Properties > AI Render",
"warning": "",
Expand All @@ -14,6 +14,7 @@

if "bpy" in locals():
import imp

imp.reload(addon_updater_ops)
imp.reload(analytics)
imp.reload(config)
Expand Down
21 changes: 11 additions & 10 deletions properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def ensure_sampler(context):
def ensure_upscaler_model(context):
# """Ensure that the upscale model is set to a valid value"""
scene = context.scene
if utils.get_active_backend().is_upscaler_model_list_loaded(context) and not scene.air_props.upscaler_model:
if (
utils.get_active_backend().is_upscaler_model_list_loaded(context)
and not scene.air_props.upscaler_model
):
scene.air_props.upscaler_model = get_default_upscaler_model()


Expand Down Expand Up @@ -117,23 +120,23 @@ class AIRProperties(bpy.types.PropertyGroup):
steps: bpy.props.IntProperty(
name="Steps",
default=30,
soft_min=10,
soft_min=1,
soft_max=50,
min=10,
min=1,
max=150,
description="How long to process the image. Values in the range of 25-50 generally work well. Higher values take longer (and use more credits) and may or may not improve results",
description="How long to process the image. Values in the range of 20-40 generally work well. Higher values take longer (and use more credits) and may or may not improve results",
)
sd_model: bpy.props.EnumProperty(
name="Stable Diffusion Model",
default=120,
items=[
('stable-diffusion-xl-1024-v1-0', 'SDXL 1.0', '', 120),
("stable-diffusion-xl-1024-v1-0", "SDXL 1.0", "", 120),
],
description="The Stable Diffusion model to use. SDXL is comparable to Midjourney. Older versions have now been removed, but newer versions may be added in the future",
)
sampler: bpy.props.EnumProperty(
name="Sampler",
default=120, # maps to DPM++ 2M, which is a good, fast sampler
default=120, # maps to DPM++ 2M, which is a good, fast sampler
items=get_available_samplers,
description="Which sampler method to use",
)
Expand Down Expand Up @@ -196,7 +199,7 @@ class AIRProperties(bpy.types.PropertyGroup):
max=8.0,
precision=1,
step=10,
description="The factor to upscale the image by. The resulting image will be its original size times this factor"
description="The factor to upscale the image by. The resulting image will be its original size times this factor",
)
do_upscale_automatically: bpy.props.BoolProperty(
name="Upscale Automatically",
Expand Down Expand Up @@ -342,9 +345,7 @@ class AIRProperties(bpy.types.PropertyGroup):
)


classes = [
AIRProperties
]
classes = [AIRProperties]


def register():
Expand Down
150 changes: 101 additions & 49 deletions sd_backends/stability_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

# CORE FUNCTIONS:

def generate(params, img_file, filename_prefix, props):

def generate(params, img_file, filename_prefix, props):
# validate the params, specifically for the Stability API
if not validate_params(params, props):
return False
Expand All @@ -27,16 +27,25 @@ def generate(params, img_file, filename_prefix, props):

# prepare the file input
files = {
'init_image': img_file,
"init_image": img_file,
}

# send the API request
try:
response = requests.post(api_url, headers=headers, files=files, data=mapped_params, timeout=request_timeout())
response = requests.post(
api_url,
headers=headers,
files=files,
data=mapped_params,
timeout=request_timeout(),
)
img_file.close()
except requests.exceptions.ReadTimeout:
img_file.close()
return operators.handle_error(f"The server timed out. Try again in a moment, or get help. [Get help with timeouts]({config.HELP_WITH_TIMEOUTS_URL})", "timeout")
return operators.handle_error(
f"The server timed out. Try again in a moment, or get help. [Get help with timeouts]({config.HELP_WITH_TIMEOUTS_URL})",
"timeout",
)

# print log info for debugging
# debug_log(response)
Expand All @@ -49,7 +58,6 @@ def generate(params, img_file, filename_prefix, props):


def upscale(img_file, filename_prefix, props):

# create the headers
headers = create_headers()

Expand All @@ -58,21 +66,24 @@ def upscale(img_file, filename_prefix, props):

# prepare the file input
files = {
'image': img_file,
"image": img_file,
}

# prepare the params
data = {
'width': utils.sanitized_upscaled_width(max_upscaled_image_size())
}
data = {"width": utils.sanitized_upscaled_width(max_upscaled_image_size())}

# send the API request
try:
response = requests.post(api_url, headers=headers, files=files, data=data, timeout=request_timeout())
response = requests.post(
api_url, headers=headers, files=files, data=data, timeout=request_timeout()
)
img_file.close()
except requests.exceptions.ReadTimeout:
img_file.close()
return operators.handle_error(f"The server timed out during upscaling. Try again in a moment, or turn off upscaling.", "timeout")
return operators.handle_error(
f"The server timed out during upscaling. Try again in a moment, or turn off upscaling.",
"timeout",
)

# print log info for debugging
# debug_log(response)
Expand All @@ -89,31 +100,36 @@ def handle_success(response, filename_prefix):
data = response.json()
output_file = utils.create_temp_file(filename_prefix + "-")
except:
return operators.handle_error(f"Couldn't create a temp file to save image", "temp_file")
return operators.handle_error(
f"Couldn't create a temp file to save image", "temp_file"
)

try:
for i, image in enumerate(data["artifacts"]):
with open(output_file, 'wb') as file:
with open(output_file, "wb") as file:
file.write(base64.b64decode(image["base64"]))

return output_file
except:
return operators.handle_error(f"DreamStudio returned an unexpected response", "unexpected_response")
return operators.handle_error(
f"DreamStudio returned an unexpected response", "unexpected_response"
)


def handle_error(response):
import json
error_key = ''

error_key = ""

try:
# convert the response to JSON (hopefully)
response_obj = response.json()

# get the message key from the response, if it exists
message = response_obj.get('message', str(response.content))
message = response_obj.get("message", str(response.content))

# handle the different types of errors
if response_obj.get('timeout', False):
if response_obj.get("timeout", False):
error_message = f"The server timed out. Try again in a moment, or get help. [Get help with timeouts]({config.HELP_WITH_TIMEOUTS_URL})"
error_key = "timeout"
else:
Expand All @@ -127,11 +143,12 @@ def handle_error(response):

# PRIVATE SUPPORT FUNCTIONS:


def create_headers():
return {
"User-Agent": f"Blender/{bpy.app.version_string}",
"Accept": "application/json",
"Authorization": f"Bearer {utils.get_dream_studio_api_key()}"
"Authorization": f"Bearer {utils.get_dream_studio_api_key()}",
}


Expand All @@ -158,38 +175,72 @@ def map_params(params):


def validate_params(params, props):
if props.sd_model.startswith('stable-diffusion-xl-1024'):
# the sdxl 1024 model only supports a few specific image sizes
if utils.are_sdxl_1024_dimensions_valid(params["width"], params["height"]):
return True
else:
return operators.handle_error(f"The SDXL model only supports these image sizes: {', '.join(utils.sdxl_1024_valid_dimensions)}. Please change your image size and try again.", "invalid_dimensions")
# validate the dimensions (the sdxl 1024 model only supports a few specific image sizes)
if props.sd_model.startswith(
"stable-diffusion-xl-1024"
) and not utils.are_sdxl_1024_dimensions_valid(params["width"], params["height"]):
return operators.handle_error(
f"The SDXL model only supports these image sizes: {', '.join(utils.sdxl_1024_valid_dimensions)}. Please change your image size and try again.",
"invalid_dimensions",
)
elif params["steps"] < 10:
return operators.handle_error(
"Steps must be set to at least 10.", "steps_too_small"
)
else:
return True


def parse_message_for_error(message):
if "\"Authorization\" is missing" in message:
if '"Authorization" is missing' in message:
return "Your DreamStudio API key is missing. Please enter it above.", "api_key"
elif "Incorrect API key" in message or "Unauthenticated" in message or "Unable to find corresponding account" in message:
return f"Your DreamStudio API key is incorrect. Please find it on the DreamStudio website, and re-enter it above. [DreamStudio website]({config.DREAM_STUDIO_URL})", "api_key"
elif (
"Incorrect API key" in message
or "Unauthenticated" in message
or "Unable to find corresponding account" in message
):
return (
f"Your DreamStudio API key is incorrect. Please find it on the DreamStudio website, and re-enter it above. [DreamStudio website]({config.DREAM_STUDIO_URL})",
"api_key",
)
elif "not have enough balance" in message:
return f"You don't have enough DreamStudio credits. Please purchase credits on the DreamStudio website or switch to a different backend in the AI Render add-on preferences. [DreamStudio website]({config.DREAM_STUDIO_URL})", "credits"
return (
f"You don't have enough DreamStudio credits. Please purchase credits on the DreamStudio website or switch to a different backend in the AI Render add-on preferences. [DreamStudio website]({config.DREAM_STUDIO_URL})",
"credits",
)
elif "invalid_prompts" in message:
return "Invalid prompt. Your prompt includes filtered words. Please change your prompt and try again.", "prompt"
return (
"Invalid prompt. Your prompt includes filtered words. Please change your prompt and try again.",
"prompt",
)
elif "image too large" in message:
return "Image size is too large. Please decrease width/height.", "dimensions_too_large"
return (
"Image size is too large. Please decrease width/height.",
"dimensions_too_large",
)
elif "invalid_height_or_width" in message:
return "Invalid width or height. They must be in the range 128-2048 in multiples of 64.", "invalid_dimensions"
return (
"Invalid width or height. They must be in the range 128-2048 in multiples of 64.",
"invalid_dimensions",
)
elif "body.sampler must be" in message:
return "Invalid sampler. Please choose a new Sampler under 'Advanced Options'.", "sampler"
return (
"Invalid sampler. Please choose a new Sampler under 'Advanced Options'.",
"sampler",
)
elif "body.cfg_scale must be" in message:
return "Invalid prompt strength. 'Prompt Strength' must be in the range 0-35.", "prompt_strength"
return (
"Invalid prompt strength. 'Prompt Strength' must be in the range 0-35.",
"prompt_strength",
)
elif "body.seed must be" in message:
return "Invalid seed value. Please choose a new 'Seed'.", "seed"
elif "body.steps must be" in message:
return "Invalid number of steps. 'Steps' must be in the range 10-150.", "steps"
return f"(Server Error) An error occurred in the Stability API. Full server response: {message}", "unknown_error"
return (
f"(Server Error) An error occurred in the Stability API. Full server response: {message}",
"unknown_error",
)


def debug_log(response):
Expand All @@ -208,31 +259,32 @@ def debug_log(response):

# PUBLIC SUPPORT FUNCTIONS:


def get_samplers():
# NOTE: Keep the number values (fourth item in the tuples) in sync with the other
# backends, like Automatic1111. These act like an internal unique ID for Blender
# to use when switching between the lists.
return [
('k_euler', 'Euler', '', 10),
('k_euler_ancestral', 'Euler a', '', 20),
('k_heun', 'Heun', '', 30),
('k_dpm_2', 'DPM2', '', 40),
('k_dpm_2_ancestral', 'DPM2 a', '', 50),
('k_lms', 'LMS', '', 60),
('K_DPMPP_2S_ANCESTRAL', 'DPM++ 2S a', '', 110),
('K_DPMPP_2M', 'DPM++ 2M', '', 120),
('ddim', 'DDIM', '', 210),
('ddpm', 'DDPM', '', 220),
("k_euler", "Euler", "", 10),
("k_euler_ancestral", "Euler a", "", 20),
("k_heun", "Heun", "", 30),
("k_dpm_2", "DPM2", "", 40),
("k_dpm_2_ancestral", "DPM2 a", "", 50),
("k_lms", "LMS", "", 60),
("K_DPMPP_2S_ANCESTRAL", "DPM++ 2S a", "", 110),
("K_DPMPP_2M", "DPM++ 2M", "", 120),
("ddim", "DDIM", "", 210),
("ddpm", "DDPM", "", 220),
]


def default_sampler():
return 'K_DPMPP_2M'
return "K_DPMPP_2M"


def get_upscaler_models(context):
return [
('esrgan-v1-x2plus', 'ESRGAN X2+', ''),
("esrgan-v1-x2plus", "ESRGAN X2+", ""),
]


Expand All @@ -241,15 +293,15 @@ def is_upscaler_model_list_loaded(context=None):


def default_upscaler_model():
return 'esrgan-v1-x2plus'
return "esrgan-v1-x2plus"


def request_timeout():
return 55


def get_image_format():
return 'PNG'
return "PNG"


def supports_negative_prompts():
Expand Down Expand Up @@ -289,4 +341,4 @@ def max_upscaled_image_size():


def is_using_sdxl_1024_model(props):
return props.sd_model.startswith('stable-diffusion-xl-1024')
return props.sd_model.startswith("stable-diffusion-xl-1024")

0 comments on commit 7387975

Please sign in to comment.