From 0f27ba10dcc359d942b195e956b37700704a40ec Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 1 Dec 2023 08:44:03 -0800 Subject: [PATCH] Install `diffusers` before checking for token. (#2072) Summary: This PR moves the `install_diffusers()` call before checking whether HF token is in the environment. This is so users don't need to set the token at the time of installation. Pull Request resolved: https://github.com/pytorch/benchmark/pull/2072 Reviewed By: aaronenyeshi Differential Revision: D51749249 Pulled By: xuzhao9 fbshipit-source-id: 7b0b37c530af17f23cde787198e88bfea6507004 --- torchbenchmark/models/stable_diffusion_text_encoder/install.py | 2 +- torchbenchmark/models/stable_diffusion_unet/install.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchbenchmark/models/stable_diffusion_text_encoder/install.py b/torchbenchmark/models/stable_diffusion_text_encoder/install.py index a9f4576593..1208f13104 100644 --- a/torchbenchmark/models/stable_diffusion_text_encoder/install.py +++ b/torchbenchmark/models/stable_diffusion_text_encoder/install.py @@ -10,8 +10,8 @@ def load_model_checkpoint(): StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, safety_checker=None) if __name__ == "__main__": + install_diffusers() if not 'HUGGING_FACE_HUB_TOKEN' in os.environ: warnings.warn("Make sure to set `HUGGINGFACE_HUB_TOKEN` so you can download weights") else: - install_diffusers() load_model_checkpoint() diff --git a/torchbenchmark/models/stable_diffusion_unet/install.py b/torchbenchmark/models/stable_diffusion_unet/install.py index a9f4576593..1208f13104 100644 --- a/torchbenchmark/models/stable_diffusion_unet/install.py +++ b/torchbenchmark/models/stable_diffusion_unet/install.py @@ -10,8 +10,8 @@ def load_model_checkpoint(): StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, safety_checker=None) if __name__ == "__main__": + install_diffusers() if not 'HUGGING_FACE_HUB_TOKEN' in os.environ: warnings.warn("Make sure to set `HUGGINGFACE_HUB_TOKEN` so you can download weights") else: - install_diffusers() load_model_checkpoint()