diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py index 93cf255e92..f772169df6 100644 --- a/apps/shark_studio/api/shark_api.py +++ b/apps/shark_studio/api/shark_api.py @@ -1,51 +1,26 @@ - # Internal API +pipelines = { + "sd1.5": ("", None), + "sd2": ("", None), + "sdxl": ("", None), + "sd3": ("", None), +} -# Used for filenames as well as the key for the global cache -def safe_name(): - pass - -def local_path(): - pass -def generate_sd_vmfb( - model: str, +# Used for filenames as well as the key for the global cache +def safe_name( + model_name: str, height: int, width: int, - steps: int, - strength: float, - guidance_scale: float, - batch_size: int = 1, - base_model_id: str, - precision: str, - controlled: bool, - **kwargs, + batch_size: int, ): pass -def load_sd_vmfb( - model: str, - weight_file: str, - height: int, - width: int, - steps: int, - strength: float, - guidance_scale: float, - batch_size: int = 1, - base_model: str, - precision: str, - controlled: bool, - try_download: bool, - **kwargs, -): - # Check if the file is already loaded and cached - # Check if the file already exists on disk - # Try to download from the web - # Generate the vmfb (generate_sd_vmfb) - # Load the vmfb and weights - # Return wrapper + +def local_path(): pass + # External API def generate_images( prompt: str, @@ -78,123 +53,106 @@ def generate_images( # Handle img2img if not isinstance(sd_init_image, list): - sd_init_image = [sd_init_image] + sd_init_image = [sd_init_image] * batch_count is_img2img = True if sd_init_image[0] is not None else False # Generate seed if < 0 # TODO + # Cache dir + # TODO + pipeline_dir = None + # Sanity checks - # Scheduler - # Base model + assert scheduler in ["EulerDiscrete"] + assert base_model in ["sd1.5", "sd2", "sdxl", "sd3"] + assert precision in ["fp16", "fp32"] + assert device in [ + "cpu", + "vulkan", + "rocm", + "hip", + "cuda", + ] # and (IREE check if the device exists) + assert resample_type in ["Nearest Neighbor"] + # Custom weights + # TODO # Custom VAE - # Precision - # Device + # TODO # Target triple - # Resample type # TODO - adapters = {} - is_controlled = False - control_mode = None - hints = [] - num_loras = 0 - import_ir = True - - # Populate model map - if model == "sd1.5": - submodels = { - "clip": None, - "scheduler": None, - "unet": None, - "vae_decode": None, - } - elif model == "sd2": - submodels = { - "clip": None, - "scheduler": None, - "unet": None, - "vae_decode": None, - } - elif model == "sdxl": - submodels = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - elif model == "sd3": + # (Re)initialize pipeline + pipeline_args = { + "height": height, + "width": width, + "batch_size": batch_size, + "precision": precision, + "device": device, + "target_triple": target_triple, + } + (existing_args, pipeline) = pipelines[base_model] + if not existing_args or not pipeline or not pipeline_args == existing_args: + # TODO: Initialize new pipeline + if base_model == "sd1.5": + pass + elif base_model == "sd2": + new_pipeline = SharkSDPipeline( + hf_model_name="stabilityai/stable-diffusion-2-1", + scheduler_id=scheduler, + height=height, + width=width, + precision=precision, + max_length=64, + batch_size=batch_size, + num_inference_steps=steps, + device=device, # TODO: Get the IREE device ID? + iree_target_triple=target_triple, + ireec_flags={}, + attn_spec=None, # TODO: Find a better way to figure this out than hardcoding + decomp_attn=True, # TODO: Ditto + pipeline_dir=pipeline_dir, + external_weights_dir=weights, # TODO: Are both necessary still? + external_weights=weights, + custom_vae=custom_vae, + ) + elif base_model == "sdxl": + pass + elif base_model == "sd3": + pass + # existing_args = pipeline_args pass - # TODO: generate and load submodel vmfbs - for submodel in submodels: - submodels[submodel] = load_sd_vmfb( - submodel, - custom_weights, - height, - width, - steps, - strength, - guidance_scale, - batch_size, - model, - precision, - not controlnets.keys(), - True, - ) - - generated_imgs = [] + generated_images = [] for current_batch in range(batch_count): - # TODO: Batch size > 1 - - # TODO: random sample (or img2img input) - sample = None - - # TODO: encode input - prompt_embeds, negative_prompt_embeds = encode(prompt, negative_prompt) - start_time = time.time() for t in range(steps): - - # Prepare latents - - # Scale model input - latent_model_input = submodels["scheduler"].scale_model_input( - sample, - t - ) - # Run unet - latents = submodels["unet"]( - latent_model_input, - t, - (negative_prompt_embeds, prompt_embeds), - guidance_scale, + out_images = pipeline.generate_images( + prompt=prompt, + negative_prompt=negative_prompt, + image=sd_init_image[current_batch], + strength=strength, + guidance_scale=guidance_scale, + seed=seed, + ondemand=ondemand, + resample_type=resample_type, + control_mode=control_mode, + hints=hints, ) - # Step scheduler - sample = submodels["scheduler"].step( - latents, - t, - sample - ) - - # VAE decode - out_img = submodels["vae_decode"]( - sample - ) - # Processing time total_time = time.time() - start_time # text_output = f"Total image(s) generation time: {total_time:.4f}sec" # print(f"\n[LOG] {text_output}") # TODO: Add to output list - generated_imgs.append(out_img) + if not isinstance(out_images, list): + out_images = [out_images] + generated_images.extend(out_images) # TODO: Allow the user to halt the process - return generated_imgs \ No newline at end of file + return generated_images