From 78b6e4ff4bc229effaac4f5604b0b0ff650ec658 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 28 May 2024 16:45:01 -0500 Subject: [PATCH] Update setup_venv.sh with simpler install path. (#2144) * Update requirements.txt for iree-turbine (#2130) * Fix Llama2 on CPU (#2133) * Filesystem cleanup and custom model fixes (#2127) * Initial filesystem cleanup * More filesystem cleanup * Fix some formatting issues * Address comments * Remove IREE pin (fixes exe issue) (#2126) * Diagnose a build issue * Remove IREE pin * Revert the build on pull request change * Update find links for IREE packages (#2136) * (Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore (#2129) * Shark Studio SDXL support, HIP driver support, simpler device info, small fixes * Fixups to llm API/UI and ignore user config files. * Small fixes for unifying pipelines. * Update requirements.txt for iree-turbine (#2130) * Fix Llama2 on CPU (#2133) * Filesystem cleanup and custom model fixes (#2127) * Fix some formatting issues * Remove IREE pin (fixes exe issue) (#2126) * Update find links for IREE packages (#2136) * Shark Studio SDXL support, HIP driver support, simpler device info, small fixes * Abstract out SD pipelines from Studio Webui (WIP) * Switch from pin to minimum torch version and fix index url * Fix device parsing. * Fix linux setup * Fix custom weights. --------- Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com> Co-authored-by: gpetters94 * Remove leftover merge conflict line from setup script. (#2141) * Add a few requirements for ensured parity with turbine-models requirements. (#2142) * Add scipy to requirements. Adds diffusers req and a note for torchsde. * Update linux setup script. * Move brevitas install --------- Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com> Co-authored-by: gpetters94 --- .github/workflows/nightly.yml | 1 + .github/workflows/test-studio.yml | 2 - .gitignore | 8 +- apps/shark_studio/api/initializers.py | 4 +- apps/shark_studio/api/llm.py | 20 +- apps/shark_studio/api/sd.py | 506 ++++++------------ apps/shark_studio/api/utils.py | 50 ++ apps/shark_studio/modules/ckpt_processing.py | 35 +- apps/shark_studio/modules/pipeline.py | 6 +- apps/shark_studio/modules/schedulers.py | 3 +- apps/shark_studio/modules/shared_cmd_opts.py | 21 +- apps/shark_studio/tests/api_test.py | 1 + apps/shark_studio/web/ui/chat.py | 4 +- apps/shark_studio/web/ui/sd.py | 8 +- apps/shark_studio/web/utils/file_utils.py | 34 +- .../web/utils/metadata/png_metadata.py | 5 +- apps/shark_studio/web/utils/tmp_configs.py | 6 +- requirements.txt | 16 +- setup_venv.ps1 | 5 +- setup_venv.sh | 53 +- shark/iree_utils/_common.py | 2 + shark/iree_utils/compile_utils.py | 11 +- shark/iree_utils/gpu_utils.py | 41 +- shark/shark_importer.py | 5 +- 24 files changed, 377 insertions(+), 470 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index ad328bcbb2..1c68e12240 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -53,6 +53,7 @@ jobs: python process_skipfiles.py $env:SHARK_PACKAGE_VERSION=${{ env.package_version }} pip install -e . 
+ pip freeze -l pyinstaller .\apps\shark_studio\shark_studio.spec mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml index 765a6bf761..9b96bf270f 100644 --- a/.github/workflows/test-studio.yml +++ b/.github/workflows/test-studio.yml @@ -81,6 +81,4 @@ jobs: source shark.venv/bin/activate pip install -r requirements.txt --no-cache-dir pip install -e . - pip uninstall -y torch - pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html python apps/shark_studio/tests/api_test.py diff --git a/.gitignore b/.gitignore index f67152b007..7d6a0d4215 100644 --- a/.gitignore +++ b/.gitignore @@ -164,7 +164,7 @@ cython_debug/ # vscode related .vscode -# Shark related artefacts +# Shark related artifacts *venv/ shark_tmp/ *.vmfb @@ -172,6 +172,7 @@ shark_tmp/ tank/dict_configs.py *.csv reproducers/ +apps/shark_studio/web/configs # ORT related artefacts cache_models/ @@ -188,6 +189,11 @@ variants.json # models folder apps/stable_diffusion/web/models/ +# model artifacts (SHARK) +*.tempfile +*.mlir +*.vmfb + # Stencil annotators. stencil_annotator/ diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py index 48e7246df6..a8119a7d94 100644 --- a/apps/shark_studio/api/initializers.py +++ b/apps/shark_studio/api/initializers.py @@ -53,11 +53,11 @@ def initialize(): clear_tmp_imgs() from apps.shark_studio.web.utils.file_utils import ( - create_checkpoint_folders, + create_model_folders, ) # Create custom models folders if they don't exist - create_checkpoint_folders() + create_model_folders() import gradio as gr diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index a88aaa9b02..217fb6784f 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -13,7 +13,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM llm_model_map = { - "llama2_7b": { + "meta-llama/Llama-2-7b-chat-hf": { "initializer": stateless_llama.export_transformer_model, "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "compile_flags": ["--iree-opt-const-expr-hoisting=False"], @@ -155,7 +155,9 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][ + "initializer" + ]( self.hf_model_name, hf_auth_token, compile_to="torch", @@ -258,7 +260,7 @@ def format_out(results): history.append(format_out(token)) while ( - format_out(token) != llm_model_map["llama2_7b"]["stop_token"] + format_out(token) != llm_model_map[self.hf_model_name]["stop_token"] and len(history) < self.max_tokens ): dec_time = time.time() @@ -272,7 +274,7 @@ def format_out(results): self.prev_token_len = token_len + len(history) - if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: + if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): @@ -306,7 +308,7 @@ def chat_hf(self, prompt): self.first_input = False history.append(int(token)) - while token != llm_model_map["llama2_7b"]["stop_token"]: + while token != llm_model_map[self.hf_model_name]["stop_token"]: dec_time = time.time() result = self.hf_mod(token.reshape([1, 1]), 
past_key_values=pkv) history.append(int(token)) @@ -317,7 +319,7 @@ def chat_hf(self, prompt): self.prev_token_len = token_len + len(history) - if token == llm_model_map["llama2_7b"]["stop_token"]: + if token == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: @@ -347,7 +349,11 @@ def llm_chat_api(InputData: dict): else: print(f"prompt : {InputData['prompt']}") - model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b" + model_name = ( + InputData["model"] + if "model" in InputData.keys() + else "meta-llama/Llama-2-7b-chat-hf" + ) model_path = llm_model_map[model_name] device = InputData["device"] if "device" in InputData.keys() else "cpu" precision = "fp16" diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 1b37384725..83574d294d 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -4,51 +4,59 @@ import os import json import numpy as np +import copy from tqdm.auto import tqdm from pathlib import Path from random import randint -from turbine_models.custom_models.sd_inference import clip, unet, vae +from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline +from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import ( + SharkSDXLPipeline, +) + + from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label from apps.shark_studio.web.utils.file_utils import ( safe_name, get_resource_path, get_checkpoints_path, ) -from apps.shark_studio.modules.pipeline import SharkPipelineBase -from apps.shark_studio.modules.schedulers import get_schedulers -from apps.shark_studio.modules.prompt_encoding import ( - get_weighted_text_embeddings, -) + from apps.shark_studio.modules.img_processing import ( - resize_stencil, save_output_img, - resamplers, - resampler_list, ) from apps.shark_studio.modules.ckpt_processing import ( preprocessCKPT, - process_custom_pipe_weights, + save_irpa, ) -from transformers import CLIPTokenizer -from diffusers.image_processor import VaeImageProcessor - -sd_model_map = { - "clip": { - "initializer": clip.export_clip_model, - }, - "unet": { - "initializer": unet.export_unet_model, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - }, + +EMPTY_SD_MAP = { + "clip": None, + "scheduler": None, + "unet": None, + "vae_decode": None, +} + +EMPTY_SDXL_MAP = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, +} + +EMPTY_FLAGS = { + "clip": None, + "unet": None, + "vae": None, + "pipeline": None, } -class StableDiffusion(SharkPipelineBase): +class StableDiffusion: # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. 
The init # aims to be as general as possible, and the class will infer and compile @@ -61,66 +69,36 @@ def __init__( height: int, width: int, batch_size: int, + steps: int, + scheduler: str, precision: str, device: str, custom_vae: str = None, num_loras: int = 0, import_ir: bool = True, is_controlled: bool = False, - hf_auth_token=None, ): - self.model_max_length = 77 - self.batch_size = batch_size self.precision = precision - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.height = height - self.width = width - self.scheduler_obj = {} - static_kwargs = { - "pipe": { - "external_weights": "safetensors", - }, - "clip": {"hf_model_name": base_model_id}, - "unet": { - "hf_model_name": base_model_id, - "unet_model": unet.UnetModel(hf_model_name=base_model_id), - "batch_size": batch_size, - # "is_controlled": is_controlled, - # "num_loras": num_loras, - "height": height, - "width": width, - "precision": precision, - "max_length": self.model_max_length, - }, - "vae_encode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - "vae_decode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - } - super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) + self.compiled_pipeline = False + self.base_model_id = base_model_id + self.custom_vae = custom_vae + self.is_sdxl = "xl" in self.base_model_id.lower() + if self.is_sdxl: + self.turbine_pipe = SharkSDXLPipeline + self.model_map = EMPTY_SDXL_MAP + else: + self.turbine_pipe = SharkSDPipeline + self.model_map = EMPTY_SD_MAP + external_weights = "safetensors" + max_length = 64 + target_backend, self.rt_device, triple = parse_device(device) pipe_id_list = [ safe_name(base_model_id), str(batch_size), - str(self.model_max_length), + str(max_length), f"{str(height)}x{str(width)}", precision, - self.device, + triple, ] if num_loras > 0: pipe_id_list.append(str(num_loras) + "lora") @@ -129,227 +107,116 @@ def __init__( if custom_vae: pipe_id_list.append(custom_vae) self.pipe_id = "_".join(pipe_id_list) + self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) + self.weights_path = Path( + os.path.join( + get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision) + ) + ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + + decomp_attn = True + attn_spec = None + if triple in ["gfx940", "gfx942", "gfx90a"]: + decomp_attn = False + attn_spec = "mfma" + elif triple in ["gfx1100", "gfx1103"]: + decomp_attn = False + attn_spec = "wmma" + elif target_backend == "llvm-cpu": + decomp_attn = False + + self.sd_pipe = self.turbine_pipe( + hf_model_name=base_model_id, + scheduler_id=scheduler, + height=height, + width=width, + precision=precision, + max_length=max_length, + batch_size=batch_size, + num_inference_steps=steps, + device=target_backend, + iree_target_triple=triple, + ireec_flags=EMPTY_FLAGS, + attn_spec=attn_spec, + decomp_attn=decomp_attn, + pipeline_dir=self.pipeline_dir, + external_weights_dir=self.weights_path, + external_weights=external_weights, + custom_vae=custom_vae, + ) print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") - del static_kwargs gc.collect() def prepare_pipe(self, 
custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") - self.is_img2img = is_img2img - self.schedulers = get_schedulers(self.base_model_id) - - self.weights_path = os.path.join( - get_checkpoints_path(), self.safe_name(self.base_model_id) - ) - if not os.path.exists(self.weights_path): - os.mkdir(self.weights_path) - - for model in adapters: - self.model_map[model] = adapters[model] - - for submodel in self.static_kwargs: - if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) - if submodel not in ["clip", "clip2"]: - self.static_kwargs[submodel][ - "external_weights" - ] = custom_weights_params - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" + self.is_img2img = False + mlirs = copy.deepcopy(self.model_map) + vmfbs = copy.deepcopy(self.model_map) + weights = copy.deepcopy(self.model_map) + + if custom_weights: + custom_weights = os.path.join( + get_checkpoints_path("checkpoints"), + safe_name(self.base_model_id.split("/")[-1]), + custom_weights, + ) + diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) + for key in weights: + if key in ["scheduled_unet", "unet"]: + unet_weights_path = os.path.join( + diffusers_weights_path, + "unet", + "diffusion_pytorch_model.safetensors", ) - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) - - self.get_compiled_map(pipe_id=self.pipe_id) - print("\n[LOG] Pipeline successfully prepared for runtime.") - return + weights[key] = save_irpa(unet_weights_path, "unet.") + + elif key in ["clip", "prompt_encoder"]: + if not self.is_sdxl: + sd1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + weights[key] = save_irpa(sd1_path, "text_encoder_model.") + else: + clip_1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + clip_2_path = os.path.join( + diffusers_weights_path, + "text_encoder_2", + "model.safetensors", + ) + weights[key] = [ + save_irpa(clip_1_path, "text_encoder_model_1."), + save_irpa(clip_2_path, "text_encoder_model_2."), + ] + + elif key in ["vae_decode"] and weights[key] is None: + vae_weights_path = os.path.join( + diffusers_weights_path, + "vae", + "diffusion_pytorch_model.safetensors", + ) + weights[key] = save_irpa(vae_weights_path, "vae.") - def encode_prompts_weight( - self, - prompt, - negative_prompt, - do_classifier_free_guidance=True, - ): - # Encodes the prompt into text encoder hidden states. - self.load_submodels(["clip"]) - self.tokenizer = CLIPTokenizer.from_pretrained( - self.base_model_id, - subfolder="tokenizer", + vmfbs, weights = self.sd_pipe.check_prepared( + mlirs, vmfbs, weights, interactive=False ) - clip_inf_start = time.time() - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") + self.sd_pipe.load_pipeline( + vmfbs, weights, self.rt_device, self.compiled_pipeline ) - - if do_classifier_free_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - pad = (0, 0) * (len(text_embeddings.shape) - 2) - pad = pad + ( - 0, - self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1], + print( + "\n[LOG] Pipeline successfully prepared for runtime. Generating images..." 
) - text_embeddings = torch.nn.functional.pad(text_embeddings, pad) - - # SHARK: Report clip inference time - clip_inf_time = (time.time() - clip_inf_start) * 1000 - if self.ondemand: - self.unload_submodels(["clip"]) - gc.collect() - print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") - - return text_embeddings.numpy().astype(np.float16) - - def prepare_latents( - self, - generator, - num_inference_steps, - image, - strength, - ): - noise = torch.randn( - ( - self.batch_size, - 4, - self.height // 8, - self.width // 8, - ), - generator=generator, - dtype=self.dtype, - ).to("cpu") - - self.scheduler.set_timesteps(num_inference_steps) - if self.is_img2img: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - latents = self.encode_image(image) - latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) - return latents, [timesteps] - else: - self.scheduler.is_scale_input_called = True - latents = noise * self.scheduler.init_noise_sigma - return latents, self.scheduler.timesteps - - def encode_image(self, input_image): - self.load_submodels(["vae_encode"]) - vae_encode_start = time.time() - latents = self.run("vae_encode", input_image) - vae_inf_time = (time.time() - vae_encode_start) * 1000 - if self.ondemand: - self.unload_submodels(["vae_encode"]) - print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") - - return latents - - def produce_img_latents( - self, - latents, - text_embeddings, - guidance_scale, - total_timesteps, - cpu_scheduling, - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - # self.status = SD_STATE_IDLE - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) - self.load_submodels(["unet"]) - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(self.dtype).detach().numpy() - latent_model_input = self.scheduler.scale_model_input(latents, t).to( - self.dtype - ) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype), - mask, - masked_image_latents, - ], - dim=1, - ).to(self.dtype) - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - # Profiling Unet. 
- # profile_device = start_profiling(file_path="unet.rdc") - noise_pred = self.run( - "unet", - [ - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ], - ) - # end_profiling(profile_device) - - if cpu_scheduling: - noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - else: - latents = self.run("scheduler_step", (noise_pred, t, latents)) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # print( - # f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - # if self.status == SD_STATE_CANCEL: - # break - - if self.ondemand: - self.unload_submodels(["unet"]) - gc.collect() - - avg_step_time = step_time_sum / len(total_timesteps) - print(f"\n[LOG] Average step time: {avg_step_time}ms/it") - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def decode_latents(self, latents, cpu_scheduling=True): - latents_numpy = latents.to(self.dtype) - if cpu_scheduling: - latents_numpy = latents.detach().numpy() - - # profile_device = start_profiling(file_path="vae.rdc") - vae_start = time.time() - images = self.run("vae_decode", latents_numpy).to_host() - vae_inf_time = (time.time() - vae_start) * 1000 - # end_profiling(profile_device) - print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") - - images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy() - pil_images = self.image_processor.numpy_to_pil(images) - return pil_images + return def generate_images( self, prompt, negative_prompt, image, - scheduler, - steps, strength, guidance_scale, seed, @@ -359,69 +226,15 @@ def generate_images( control_mode, hints, ): - # TODO: Batched args - self.image_processor = VaeImageProcessor(do_convert_rgb=True) - self.scheduler = self.schedulers[scheduler] - self.ondemand = ondemand - if self.is_img2img: - image, _ = self.image_processor.preprocess(image, resample_type) - else: - image = None - - print("\n[LOG] Generating images...") - batched_args = [ - prompt, - negative_prompt, - image, - ] - for arg in batched_args: - if not isinstance(arg, list): - arg = [arg] * self.batch_size - if len(arg) < self.batch_size: - arg = arg * self.batch_size - else: - arg = [arg[i] for i in range(self.batch_size)] - - text_embeddings = self.encode_prompts_weight( + img = self.sd_pipe.generate_images( prompt, negative_prompt, + 1, + guidance_scale, + seed, + return_imgs=True, ) - - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - - generator = torch.manual_seed(seed) - - init_latents, final_timesteps = self.prepare_latents( - generator=generator, - num_inference_steps=steps, - image=image, - strength=strength, - ) - - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=final_timesteps, - cpu_scheduling=True, # until we have schedulers through Turbine - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_submodels(["vae_decode"]) - for i in tqdm(range(0, latents.shape[0], self.batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + self.batch_size], - cpu_scheduling=True, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_submodels(["vae_decode"]) - - return all_imgs + return img def shark_sd_fn_dict_input( 
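For reference, a minimal usage sketch of the refactored Studio API follows. It is not part of the patch: the model id, device string, scheduler name, and prompt values are illustrative assumptions, and it presumes SHARK Studio and its requirements are installed on a machine with a supported AMD GPU.

    # Sketch only: drives the refactored StableDiffusion wrapper end to end.
    # All literal values below are assumptions, not taken from this patch.
    from apps.shark_studio.api.sd import StableDiffusion

    sd = StableDiffusion(
        base_model_id="stabilityai/stable-diffusion-xl-base-1.0",
        height=1024,
        width=1024,
        batch_size=1,
        steps=30,
        scheduler="EulerDiscrete",
        precision="fp16",
        # Device strings are assumed to be in the form Studio's device list produces;
        # parse_device() resolves this to (target backend, runtime device, gfx triple).
        device="AMD Radeon RX 7900 XTX => rocm://0",
    )
    # Compiles or reuses the turbine pipeline artifacts (vmfbs + .irpa weights)
    # and loads them onto the runtime device.
    sd.prepare_pipe(custom_weights=None, adapters={}, embeddings={}, is_img2img=False)
    # generate_images() now delegates to sd_pipe.generate_images() and returns PIL images.
    images = sd.generate_images(
        prompt="a watercolor painting of a lighthouse",
        negative_prompt="blurry, low quality",
        image=None,
        strength=1.0,
        guidance_scale=7.5,
        seed=42,
        ondemand=False,
        repeatable_seeds=False,
        resample_type="Nearest Neighbor",
        control_mode=None,
        hints=[],
    )
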
@@ -481,6 +294,7 @@ def shark_sd_fn( control_mode = None hints = [] num_loras = 0 + import_ir = True for i in embeddings: num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: @@ -514,8 +328,10 @@ def shark_sd_fn( "device": device, "custom_vae": custom_vae, "num_loras": num_loras, - "import_ir": cmd_opts.import_mlir, + "import_ir": import_ir, "is_controlled": is_controlled, + "steps": steps, + "scheduler": scheduler, } submit_prep_kwargs = { "custom_weights": custom_weights, @@ -527,8 +343,6 @@ def shark_sd_fn( "prompt": prompt, "negative_prompt": negative_prompt, "image": sd_init_image, - "steps": steps, - "scheduler": scheduler, "strength": strength, "guidance_scale": guidance_scale, "seed": seed, @@ -566,9 +380,9 @@ def shark_sd_fn( for current_batch in range(batch_count): start_time = time.time() out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) - total_time = time.time() - start_time - text_output = f"Total image(s) generation time: {total_time:.4f}sec" - print(f"\n[LOG] {text_output}") + # total_time = time.time() - start_time + # text_output = f"Total image(s) generation time: {total_time:.4f}sec" + # print(f"\n[LOG] {text_output}") # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: @@ -596,13 +410,19 @@ def view_json_file(file_path): return content +def safe_name(name): + return name.replace("/", "_").replace("\\", "_").replace(".", "_") + + if __name__ == "__main__": from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj global_obj._init() - sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) + sd_json = view_json_file( + get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json")) + ) sd_kwargs = json.loads(sd_json) for arg in vars(cmd_opts): if arg in sd_kwargs: diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index e9268aa83b..0516255d2b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -71,6 +71,8 @@ def get_devices_by_name(driver_name): available_devices.extend(cuda_devices) rocm_devices = get_devices_by_name("rocm") available_devices.extend(rocm_devices) + hip_devices = get_devices_by_name("hip") + available_devices.extend(hip_devices) cpu_device = get_devices_by_name("cpu-sync") available_devices.extend(cpu_device) cpu_device = get_devices_by_name("cpu-task") @@ -127,6 +129,54 @@ def set_iree_runtime_flags(): set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) +def parse_device(device_str): + from shark.iree_utils.compile_utils import ( + clean_device_info, + get_iree_target_triple, + iree_target_map, + ) + + rt_driver, device_id = clean_device_info(device_str) + target_backend = iree_target_map(rt_driver) + if device_id: + rt_device = f"{rt_driver}://{device_id}" + else: + rt_device = rt_driver + + match target_backend: + case "vulkan-spirv": + triple = get_iree_target_triple(device_str) + return target_backend, rt_device, triple + case "rocm": + triple = get_rocm_target_chip(device_str) + return target_backend, rt_device, triple + case "llvm-cpu": + return "llvm-cpu", "local-task", "x86_64-linux-gnu" + + +def get_rocm_target_chip(device_str): + # TODO: Use a data file to map device_str to target chip. 
+ rocm_chip_map = { + "6700": "gfx1031", + "6800": "gfx1030", + "6900": "gfx1030", + "7900": "gfx1100", + "MI300X": "gfx942", + "MI300A": "gfx940", + "MI210": "gfx90a", + "MI250": "gfx90a", + "MI100": "gfx908", + "MI50": "gfx906", + "MI60": "gfx906", + } + for key in rocm_chip_map: + if key in device_str: + return rocm_chip_map[key] + raise AssertionError( + f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues." + ) + + def get_all_devices(driver_name): """ Inputs: driver_name diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index 08681f6c56..433df13654 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -2,10 +2,16 @@ import json import re import requests +import torch +import safetensors +from shark_turbine.aot.params import ( + ParameterArchiveBuilder, +) from io import BytesIO from pathlib import Path from tqdm import tqdm from omegaconf import OmegaConf +from diffusers import StableDiffusionPipeline from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( download_from_original_stable_diffusion_ckpt, @@ -14,21 +20,21 @@ ) -def get_path_to_diffusers_checkpoint(custom_weights): +def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): path = Path(custom_weights) diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) + diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}") complete_path_to_diffusers = diffusers_path / diffusers_directory_name complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) path_to_diffusers = complete_path_to_diffusers.as_posix() return path_to_diffusers -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) +def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) - return + return path_to_diffusers else: print( "Diffusers' checkpoint will be identified here : ", @@ -50,8 +56,24 @@ def preprocessCKPT(custom_weights, is_inpaint=False): from_safetensors=from_safetensors, num_in_channels=num_in_channels, ) + if precision == "fp16": + pipe.to(dtype=torch.float16) pipe.save_pretrained(path_to_diffusers) + del pipe print("Loading complete") + return path_to_diffusers + + +def save_irpa(weights_path, prepend_str): + weights = safetensors.torch.load_file(weights_path) + archive = ParameterArchiveBuilder() + for key in weights.keys(): + new_key = prepend_str + key + archive.add_tensor(new_key, weights[key]) + + irpa_file = weights_path.replace(".safetensors", ".irpa") + archive.save(irpa_file) + return irpa_file def convert_original_vae(vae_checkpoint): @@ -87,6 +109,7 @@ def process_custom_pipe_weights(custom_weights): ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) custom_weights_params = custom_weights + return custom_weights_params, custom_weights_tgt @@ -98,7 +121,7 @@ def get_civitai_checkpoint(url: str): base_filename = re.findall( '"([^"]*)"', response.headers["Content-Disposition"] )[0] - destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / 
base_filename + destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename # we don't have this model downloaded yet if not destination_path.is_file(): diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index 053858c5df..2daedc3352 100644 --- a/apps/shark_studio/modules/pipeline.py +++ b/apps/shark_studio/modules/pipeline.py @@ -41,7 +41,7 @@ def __init__( self.device, self.device_id = clean_device_info(device) self.import_mlir = import_mlir self.iree_module_dict = {} - self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp")) + self.tmp_dir = get_resource_path(cmd_opts.tmp_dir) if not os.path.exists(self.tmp_dir): os.mkdir(self.tmp_dir) self.tempfiles = {} @@ -55,9 +55,7 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: # and your model map is populated with any IR - unique model IDs and their static params, # call this method to get the artifacts associated with your map. self.pipe_id = self.safe_name(pipe_id) - self.pipe_vmfb_path = Path( - os.path.join(get_checkpoints_path(".."), self.pipe_id) - ) + self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True) if submodel == "None": print("\n[LOG] Gathering any pre-compiled artifacts....") diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 3e931b1c78..56df8973d0 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -101,11 +101,12 @@ def export_scheduler_model(model): scheduler_model_map = { + "PNDM": export_scheduler_model("PNDMScheduler"), + "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"), "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), - "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index 7992660d96..d7f5f002d5 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -339,7 +339,7 @@ def is_valid_file(arg): p.add_argument( "--output_dir", type=str, - default=None, + default=os.path.join(os.getcwd(), "generated_imgs"), help="Directory path to save the output images and json.", ) @@ -613,12 +613,27 @@ def is_valid_file(arg): ) p.add_argument( - "--ckpt_dir", + "--tmp_dir", + type=str, + default=os.path.join(os.getcwd(), "shark_tmp"), + help="Path to tmp directory", +) + +p.add_argument( + "--config_dir", type=str, - default="../models", + default=os.path.join(os.getcwd(), "configs"), + help="Path to config directory", +) + +p.add_argument( + "--model_dir", + type=str, + default=os.path.join(os.getcwd(), "models"), help="Path to directory where all .ckpts are stored in order to populate " "them in the web UI.", ) + # TODO: replace API flag when these can be run together p.add_argument( "--ui", diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 7bed2cb7b0..49f4482576 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -36,6 +36,7 
@@ def test01_LLMSmall(self): device="cpu", precision="fp32", quantization="None", + streaming_llm=True, ) count = 0 label = "Turkishoure Turkish" diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index f41eaaaba0..54ae4a139f 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -9,6 +9,7 @@ llm_model_map, LanguageModel, ) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj B_SYS, E_SYS = "", "" @@ -64,6 +65,7 @@ def chat_fn( external_weights="safetensors", use_system_prompt=prompt_prefix, streaming_llm=streaming_llm, + hf_auth_token=cmd_opts.hf_auth_token, ) history[-1][-1] = "Getting the model ready... Done" yield history, "" @@ -135,7 +137,7 @@ def view_json_file(file_obj): streaming_llm = gr.Checkbox( label="Run in streaming mode (requires recompilation)", value=True, - interactive=True, + interactive=False, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 799504cb75..a4df173b1c 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -17,7 +17,6 @@ write_default_sd_config, ) from apps.shark_studio.api.sd import ( - sd_model_map, shark_sd_fn_dict_input, cancel_sd, ) @@ -45,11 +44,10 @@ import apps.shark_studio.web.utils.globals as global_obj sd_default_models = [ - "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", ] @@ -281,14 +279,14 @@ def base_model_changed(base_model_id): with gr.Row(): height = gr.Slider( 384, - 768, + 1024, value=cmd_opts.height, step=8, label="\U00002195\U0000FE0F Height", ) width = gr.Slider( 384, - 768, + 1024, value=cmd_opts.width, step=8, label="\U00002194\U0000FE0F Width", diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index 0f1953f5ac..3619055676 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -47,7 +47,7 @@ def write_default_sd_config(path): def safe_name(name): - return name.replace("/", "_").replace("-", "_") + return name.split("/")[-1].replace("-", "_") def get_path_stem(path): @@ -66,33 +66,39 @@ def get_resource_path(path): def get_configs_path() -> Path: - configs = get_resource_path(os.path.join("..", "configs")) + configs = get_resource_path(cmd_opts.config_dir) if not os.path.exists(configs): os.mkdir(configs) - return Path(get_resource_path("../configs")) + return Path(configs) def get_generated_imgs_path() -> Path: - return Path( - cmd_opts.output_dir - if cmd_opts.output_dir - else get_resource_path("../generated_imgs") - ) + outputs = get_resource_path(cmd_opts.output_dir) + if not os.path.exists(outputs): + os.mkdir(outputs) + return Path(outputs) + + +def get_tmp_path() -> Path: + tmpdir = get_resource_path(cmd_opts.model_dir) + if not os.path.exists(tmpdir): + os.mkdir(tmpdir) + return Path(tmpdir) def get_generated_imgs_todays_subdir() -> str: return dt.now().strftime("%Y%m%d") -def create_checkpoint_folders(): +def create_model_folders(): dir = ["checkpoints", "vae", "lora", "vmfb"] - if not os.path.isdir(cmd_opts.ckpt_dir): + if not os.path.isdir(cmd_opts.model_dir): try: - os.makedirs(cmd_opts.ckpt_dir) + os.makedirs(cmd_opts.model_dir) except OSError: sys.exit( - f"Invalid 
--ckpt_dir argument, " - f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created." + f"Invalid --model_dir argument, " + f"{cmd_opts.model_dir} folder does not exist, and cannot be created." ) for root in dir: @@ -100,7 +106,7 @@ def create_checkpoint_folders(): def get_checkpoints_path(model_type=""): - return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type)) + return get_resource_path(os.path.join(cmd_opts.model_dir, model_type)) def get_checkpoints(model_type="checkpoints"): diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py index 72f663f246..d1cadc1e00 100644 --- a/apps/shark_studio/web/utils/metadata/png_metadata.py +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -3,9 +3,8 @@ from apps.shark_studio.web.utils.file_utils import ( get_checkpoint_pathfile, ) -from apps.shark_studio.api.sd import ( - sd_model_map, -) +from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map + from apps.shark_studio.modules.schedulers import ( scheduler_model_map, ) diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py index 4415276ea3..ebbc4ae6af 100644 --- a/apps/shark_studio/web/utils/tmp_configs.py +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -2,7 +2,9 @@ import shutil from time import time -shark_tmp = os.path.join(os.getcwd(), "shark_tmp/") +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + +shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/") def clear_tmp_mlir(): @@ -15,7 +17,7 @@ def clear_tmp_mlir(): and filename.endswith(".mlir") ] for filename in mlir_files: - os.remove(shark_tmp + filename) + os.remove(os.path.join(shark_tmp, filename)) print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") diff --git a/requirements.txt b/requirements.txt index c2a598978d..327c36b131 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --f https://openxla.github.io/iree/pip-release-links.html +-f https://download.pytorch.org/whl/nightly/cpu +-f https://iree.dev/pip-release-links.html --pre setuptools wheel -torch==2.3.0 +torch>=2.3.0 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main -turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models + # SHARK Runner tqdm @@ -17,8 +18,6 @@ google-cloud-storage # Testing pytest -pytest-xdist -pytest-forked Pillow parameterized @@ -26,8 +25,10 @@ parameterized #accelerate is now required for diffusers import from ckpt. accelerate scipy +transformers==4.37.1 +torchsde # Required for Stable Diffusion SDE schedulers. ftfy -gradio==4.19.2 +gradio==4.29.0 altair omegaconf # 0.3.2 doesn't have binaries for arm64 @@ -35,6 +36,7 @@ safetensors==0.3.1 py-cpuinfo pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions mpmath==1.3.0 +optimum # Keep PyInstaller at the end. 
Sometimes Windows Defender flags it but most folks can continue even if it errors pefile diff --git a/setup_venv.ps1 b/setup_venv.ps1 index c67b8fc83b..40f26ecf08 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -88,8 +88,7 @@ else {python -m venv .\shark.venv\} .\shark.venv\Scripts\activate python -m pip install --upgrade pip pip install wheel -pip install -r requirements.txt -# remove this when windows DLL issues are fixed from LLVM changes -pip install --force-reinstall https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_compiler-20240326.843-cp311-cp311-win_amd64.whl https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_runtime-20240326.843-cp311-cp311-win_amd64.whl +pip install --pre -r requirements.txt +pip install -e . Write-Host "Source your venv with ./shark.venv/Scripts/activate" diff --git a/setup_venv.sh b/setup_venv.sh index 64f769d794..9578cefb16 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -49,28 +49,18 @@ Red=`tput setaf 1` Green=`tput setaf 2` Yellow=`tput setaf 3` +RUNTIME="https://iree.dev/pip-release-links.html" +PYTORCH_URL="https://download.pytorch.org/whl/nightly/cpu/" + # Upgrade pip and install requirements. $PYTHON -m pip install --upgrade pip || die "Could not upgrade pip" -$PYTHON -m pip install --upgrade -r "$TD/requirements.txt" -if [[ $(uname -s) = 'Darwin' ]]; then - echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch." - $PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC - $PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/ -else - $PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ - if [ $? -eq 0 ];then - echo "Successfully Installed torch-mlir" - else - echo "Could not install torch-mlir" >&2 - fi -fi -if [[ -z "${USE_IREE}" ]]; then - rm .use-iree - RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html" -else - touch ./.use-iree - RUNTIME="https://openxla.github.io/iree/pip-release-links.html" +$PYTHON -m pip install --upgrade --pre torch torchvision torchaudio --index-url $PYTORCH_URL +if [[ -z "${NO_BREVITAS}" ]]; then + $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev fi +$PYTHON -m pip install --pre --upgrade -r "$TD/requirements.txt" + + if [[ -z "${NO_BACKEND}" ]]; then echo "Installing ${RUNTIME}..." $PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime @@ -78,31 +68,8 @@ else echo "Not installing a backend, please make sure to add your backend to PYTHONPATH" fi -if [[ $(uname -s) = 'Darwin' ]]; then - PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/ -else - PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/ -fi - -$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL} - -if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then - T_VER=$($PYTHON -m pip show torch | grep Version) - T_VER_MIN=${T_VER:14:12} - TV_VER=$($PYTHON -m pip show torchvision | grep Version) - TV_VER_MAJ=${TV_VER:9:6} - $PYTHON -m pip uninstall -y torchvision - $PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ - if [ $? -eq 0 ];then - echo "Successfully Installed torch + cu118." - else - echo "Could not install torch + cu118." 
>&2 - fi -fi +$PYTHON -m pip install --no-warn-conflicts -e . -if [[ -z "${NO_BREVITAS}" ]]; then - $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev -fi if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then echo "${Green}Before running examples activate venv with:" diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index c58405b46e..1d022f67e4 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -76,6 +76,7 @@ def get_supported_device_list(): "vulkan": "vulkan", "metal": "metal", "rocm": "rocm", + "hip": "hip", "intel-gpu": "level_zero", } @@ -94,6 +95,7 @@ def iree_target_map(device): "vulkan": "vulkan-spirv", "metal": "metal", "rocm": "rocm", + "hip": "rocm", "intel-gpu": "opencl-spirv", } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 5fd1d4006a..f93c8fef2e 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -62,13 +62,16 @@ def get_iree_device_args(device, extra_args=[]): from shark.iree_utils.gpu_utils import get_iree_rocm_args return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) + if device == "hip": + from shark.iree_utils.gpu_utils import get_iree_rocm_args + return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True) return [] def get_iree_target_triple(device): args = get_iree_device_args(device) for flag in args: - if "triple" in flag.split("-"): - triple = flag.split("=") + if "triple" in flag: + triple = flag.split("=")[-1] return triple return "" @@ -89,9 +92,9 @@ def clean_device_info(raw_device): if len(device_id) <= 2: device_id = int(device_id) - if device not in ["rocm", "vulkan"]: + if device not in ["hip", "rocm", "vulkan"]: device_id = None - if device in ["rocm", "vulkan"] and device_id == None: + if device in ["hip", "rocm", "vulkan"] and device_id == None: device_id = 0 return device, device_id diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py index 0eba67ff53..db6ef14e34 100644 --- a/shark/iree_utils/gpu_utils.py +++ b/shark/iree_utils/gpu_utils.py @@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args): return None -def get_rocm_device_arch(device_num=0, extra_args=[]): +def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False): # ROCM Device Arch selection: # 1 : User given device arch using `--iree-rocm-target-chip` flag # 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index @@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]): arch_in_device_dump = None # get rocm arch from iree dump devices - def get_devices_info_from_dump(dump): + def get_devices_info_from_dump(dump, driver): from os import linesep - - dump_clean = list( - filter( - lambda s: "--device=rocm" in s or "gpu-arch-name:" in s, - dump.split(linesep), + + if driver == "hip": + dump_clean = list( + filter( + lambda s: "AMD" in s, + dump.split(linesep), + ) + ) + else: + dump_clean = list( + filter( + lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s, + dump.split(linesep), + ) ) - ) arch_pairs = [ ( dump_clean[i].split("=")[1].strip(), @@ -87,16 +95,17 @@ def get_devices_info_from_dump(dump): return arch_pairs dump_device_info = None + driver = "hip" if hip_driver else "rocm" try: dump_device_info = run_cmd( - "iree-run-module --dump_devices=rocm", raise_err=True + "iree-run-module --dump_devices=" + driver, raise_err=True ) except Exception as e: - print("could not execute `iree-run-module 
--dump_devices=rocm`") + print("could not execute `iree-run-module --dump_devices=" + driver + "`") if dump_device_info is not None: device_num = 0 if device_num is None else device_num - device_arch_pairs = get_devices_info_from_dump(dump_device_info[0]) + device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver) if len(device_arch_pairs) > device_num: # can find arch in the list arch_in_device_dump = device_arch_pairs[device_num][1] @@ -107,24 +116,22 @@ def get_devices_info_from_dump(dump): default_rocm_arch = "gfx1100" print( "Did not find ROCm architecture from `--iree-rocm-target-chip` flag" - "\n or from `iree-run-module --dump_devices=rocm` command." + "\n or from `iree-run-module --dump_devices` command." f"\nUsing {default_rocm_arch} as ROCm arch for compilation." ) return default_rocm_arch # Get the default gpu args given the architecture. -def get_iree_rocm_args(device_num=0, extra_args=[]): +def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False): ireert.flags.FUNCTION_INPUT_VALIDATION = False - rocm_flags = ["--iree-rocm-link-bc=true"] - + rocm_flags = [] if check_rocm_device_arch_in_args(extra_args) is None: - rocm_arch = get_rocm_device_arch(device_num, extra_args) + rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver) rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}") return rocm_flags - # Some constants taken from cuda.h CUDA_SUCCESS = 0 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16 diff --git a/shark/shark_importer.py b/shark/shark_importer.py index 3d585a8d0d..e05b81fa0d 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -6,6 +6,7 @@ import os import hashlib +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts def create_hash(file_name): with open(file_name, "rb") as f: @@ -120,7 +121,7 @@ def import_mlir( is_dynamic=False, tracing_required=False, func_name="forward", - save_dir="./shark_tmp/", + save_dir=cmd_opts.tmp_dir, #"./shark_tmp/", mlir_type="linalg", ): if self.frontend in ["torch", "pytorch"]: @@ -806,7 +807,7 @@ def save_mlir( model_name + "_" + frontend + "_" + mlir_dialect + ".mlir" ) if dir == "": - dir = os.path.join(".", "shark_tmp") + dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp") mlir_path = os.path.join(dir, model_name_mlir) print(f"saving {model_name_mlir} to {dir}") if not os.path.exists(dir):