Skip to content

Commit

Permalink
SD - Implement seed arrays for batch runs (#1690)
Browse files Browse the repository at this point in the history
* SD Scripts and UI tabs that support batch_count can now take a
string containing a JSON array, or a list of integers, as their seed
input.
* Each batch in a run will now take the seed specified at the
corresponding array index if one exists. If there is no seed at
that index, the seed value will be treated as -1 and a random
seed will be assigned at that position. If an integer rather than
a list or json array has been, everything works as before.
* UI seed input controls are now Textboxes with info lines about
the seed formats allowed.
* UI error handling updated to be more helpful if the seed input is
invalid.
  • Loading branch information
one-lithe-rune authored Jul 25, 2023
1 parent 453e465 commit 289f983
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 36 deletions.
1 change: 1 addition & 0 deletions apps/stable_diffusion/src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
parse_seed_input,
batch_seeds,
get_path_stem,
get_extended_name,
Expand Down
4 changes: 2 additions & 2 deletions apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def is_valid_file(arg):

p.add_argument(
"--seed",
type=int,
type=str,
default=-1,
help="The seed to use. -1 for a random one.",
help="The seed or list of seeds to use. -1 for a random one.",
)

p.add_argument(
Expand Down
51 changes: 40 additions & 11 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,28 +727,57 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):

# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed


# Generate a set of seeds, using as the first seed of the set,
# optionally using it as the rng seed for subsequent seeds in the set
def batch_seeds(seed, batch_count, repeatable=False):
# use the passed seed as the initial seed of the batch
seeds = [sanitize_seed(seed)]
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None

if isinstance(seed_input, int):
return [seed_input]

if isinstance(seed_input, list) and all(
type(seed) is int for seed in seed_input
):
return seed_input

raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)


# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(
seed_input: str | list | int, batch_count: int, repeatable=False
):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)

# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

if repeatable:
# use the initial seed as the rng generator seed
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(seed)
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
seed_random(str(seeds))

# generate the additional seeds
for i in range(1, batch_count):
seeds.append(sanitize_seed(-1))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]

if repeatable:
# reset the rng back to normal
Expand Down
14 changes: 9 additions & 5 deletions apps/stable_diffusion/web/ui/img2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def img2img_inf(
steps: int,
strength: float,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
Expand Down Expand Up @@ -230,10 +230,12 @@ def img2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"STRENGTH": strength}
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None

for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
Expand Down Expand Up @@ -617,8 +619,10 @@ def create_canvas(w, h):
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
Expand Down
13 changes: 9 additions & 4 deletions apps/stable_diffusion/web/ui/inpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def inpaint_inf(
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
Expand Down Expand Up @@ -181,10 +181,13 @@ def inpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None

for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
Expand Down Expand Up @@ -514,8 +517,10 @@ def inpaint_api(
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
Expand Down
6 changes: 4 additions & 2 deletions apps/stable_diffusion/web/ui/lora_train_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.src import prompt_examples, args, utils
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
Expand Down Expand Up @@ -168,7 +168,9 @@
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
value=utils.parse_seed_input(args.seed)[0],
precision=0,
label="Seed",
)
device = gr.Dropdown(
elem_id="device",
Expand Down
13 changes: 9 additions & 4 deletions apps/stable_diffusion/web/ui/outpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def outpaint_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
Expand Down Expand Up @@ -178,7 +178,10 @@ def outpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None

left = True if "left" in directions else False
right = True if "right" in directions else False
Expand Down Expand Up @@ -542,8 +545,10 @@ def outpaint_api(
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
Expand Down
13 changes: 9 additions & 4 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def txt2img_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
Expand Down Expand Up @@ -178,8 +178,11 @@ def txt2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None

for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
Expand Down Expand Up @@ -481,8 +484,10 @@ def txt2img_api(
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
Expand Down
13 changes: 9 additions & 4 deletions apps/stable_diffusion/web/ui/upscaler_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def upscaler_inf(
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
Expand Down Expand Up @@ -177,8 +177,11 @@ def upscaler_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"NOISE LEVEL": noise_level}
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None

for current_batch in range(batch_count):
low_res_img = image
Expand Down Expand Up @@ -534,8 +537,10 @@ def upscaler_api(
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
Expand Down

0 comments on commit 289f983

Please sign in to comment.