diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py index 265bf17773..ff66e2ebe2 100644 --- a/apps/stable_diffusion/src/utils/__init__.py +++ b/apps/stable_diffusion/src/utils/__init__.py @@ -28,6 +28,7 @@ fetch_and_update_base_model_id, get_path_to_diffusers_checkpoint, sanitize_seed, + parse_seed_input, batch_seeds, get_path_stem, get_extended_name, diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index 4db5534ad8..7c17854439 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -66,9 +66,9 @@ def is_valid_file(arg): p.add_argument( "--seed", - type=int, + type=str, default=-1, - help="The seed to use. -1 for a random one.", + help="The seed or list of seeds to use. -1 for a random one.", ) p.add_argument( diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index e350837fa9..6cfd904c93 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -727,7 +727,8 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""): # Generate and return a new seed if the provided one is not in the # supported range (including -1) -def sanitize_seed(seed): +def sanitize_seed(seed: int | str): + seed = int(seed) uint32_info = np.iinfo(np.uint32) uint32_min, uint32_max = uint32_info.min, uint32_info.max if seed < uint32_min or seed >= uint32_max: @@ -735,20 +736,48 @@ def sanitize_seed(seed): return seed -# Generate a set of seeds, using as the first seed of the set, -# optionally using it as the rng seed for subsequent seeds in the set -def batch_seeds(seed, batch_count, repeatable=False): - # use the passed seed as the initial seed of the batch - seeds = [sanitize_seed(seed)] +# take a seed expression in an input format and convert it to +# a list of integers, where possible +def parse_seed_input(seed_input: str | list | int): + if isinstance(seed_input, str): + try: + seed_input = json.loads(seed_input) + except (ValueError, TypeError): + seed_input = None + + if isinstance(seed_input, int): + return [seed_input] + + if isinstance(seed_input, list) and all( + type(seed) is int for seed in seed_input + ): + return seed_input + + raise TypeError( + "Seed input must be an integer or an array of integers in JSON format" + ) + + +# Generate a set of seeds from an input expression for batch_count batches, +# optionally using that input as the rng seed for any randomly generated seeds. +def batch_seeds( + seed_input: str | list | int, batch_count: int, repeatable=False +): + # turn the input into a list if possible + seeds = parse_seed_input(seed_input) + + # slice or pad the list to be of batch_count length + seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds)) if repeatable: - # use the initial seed as the rng generator seed + # set seed for the rng based on what we have so far saved_random_state = random_getstate() - seed_random(seed) + if all(seed < 0 for seed in seeds): + seeds[0] = sanitize_seed(seeds[0]) + seed_random(str(seeds)) - # generate the additional seeds - for i in range(1, batch_count): - seeds.append(sanitize_seed(-1)) + # generate any seeds that are unspecified + seeds = [sanitize_seed(seed) for seed in seeds] if repeatable: # reset the rng back to normal diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index f731a7118d..fbf0be1921 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -50,7 +50,7 @@ def img2img_inf( steps: int, strength: float, guidance_scale: float, - seed: int, + seed: str | int, batch_count: int, batch_size: int, scheduler: str, @@ -230,10 +230,12 @@ def img2img_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - seeds = [] - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) extra_info = {"STRENGTH": strength} text_output = "" + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None for current_batch in range(batch_count): out_imgs = global_obj.get_sd_obj().generate_images( @@ -617,8 +619,10 @@ def create_canvas(w, h): visible=False, ) with gr.Row(): - seed = gr.Number( - value=args.seed, precision=0, label="Seed" + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", ) device = gr.Dropdown( elem_id="device", diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 73005e3021..e10895f65c 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -49,7 +49,7 @@ def inpaint_inf( inpaint_full_res_padding: int, steps: int, guidance_scale: float, - seed: int, + seed: str | int, batch_count: int, batch_size: int, scheduler: str, @@ -181,10 +181,13 @@ def inpaint_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) image = image_dict["image"] mask_image = image_dict["mask"] text_output = "" + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None for current_batch in range(batch_count): out_imgs = global_obj.get_sd_obj().generate_images( @@ -514,8 +517,10 @@ def inpaint_api( visible=False, ) with gr.Row(): - seed = gr.Number( - value=args.seed, precision=0, label="Seed" + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", ) device = gr.Dropdown( elem_id="device", diff --git a/apps/stable_diffusion/web/ui/lora_train_ui.py b/apps/stable_diffusion/web/ui/lora_train_ui.py index 3a8b07ebbb..6580390199 100644 --- a/apps/stable_diffusion/web/ui/lora_train_ui.py +++ b/apps/stable_diffusion/web/ui/lora_train_ui.py @@ -3,7 +3,7 @@ import gradio as gr from PIL import Image from apps.stable_diffusion.scripts import lora_train -from apps.stable_diffusion.src import prompt_examples, args +from apps.stable_diffusion.src import prompt_examples, args, utils from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -168,7 +168,9 @@ stop_batch = gr.Button("Stop Batch") with gr.Row(): seed = gr.Number( - value=args.seed, precision=0, label="Seed" + value=utils.parse_seed_input(args.seed)[0], + precision=0, + label="Seed", ) device = gr.Dropdown( elem_id="device", diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index 2c660ebb5a..423bfb291c 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -49,7 +49,7 @@ def outpaint_inf( width: int, steps: int, guidance_scale: float, - seed: int, + seed: str, batch_count: int, batch_size: int, scheduler: str, @@ -178,7 +178,10 @@ def outpaint_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None left = True if "left" in directions else False right = True if "right" in directions else False @@ -542,8 +545,10 @@ def outpaint_api( visible=False, ) with gr.Row(): - seed = gr.Number( - value=args.seed, precision=0, label="Seed" + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", ) device = gr.Dropdown( elem_id="device", diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 9a2fd452bc..8e516e95a2 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -46,7 +46,7 @@ def txt2img_inf( width: int, steps: int, guidance_scale: float, - seed: int, + seed: str | int, batch_count: int, batch_size: int, scheduler: str, @@ -178,8 +178,11 @@ def txt2img_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) text_output = "" + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None for current_batch in range(batch_count): out_imgs = global_obj.get_sd_obj().generate_images( @@ -481,8 +484,10 @@ def txt2img_api( label="Repeatable Seeds", ) with gr.Row(): - seed = gr.Number( - value=args.seed, precision=0, label="Seed" + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", ) device = gr.Dropdown( elem_id="device", diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index f401bd6b03..302c0d5144 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -42,7 +42,7 @@ def upscaler_inf( steps: int, noise_level: int, guidance_scale: float, - seed: int, + seed: str, batch_count: int, batch_size: int, scheduler: str, @@ -177,8 +177,11 @@ def upscaler_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) extra_info = {"NOISE LEVEL": noise_level} + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None for current_batch in range(batch_count): low_res_img = image @@ -534,8 +537,10 @@ def upscaler_api( visible=False, ) with gr.Row(): - seed = gr.Number( - value=args.seed, precision=0, label="Seed" + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", ) device = gr.Dropdown( elem_id="device",