Skip to content

Commit

Permalink
Update install.py
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim authored Jul 13, 2023
1 parent 9efa3ee commit 623ac4b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchbenchmark/models/stable_diffusion/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
MODEL_NAME = "stabilityai/stable-diffusion-2"

def load_model_checkpoint():
from diffusers import StableDiffusionPipeline
StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, safety_checker=None)

def main():
if not 'HUGGING_FACE_HUB_TOKEN' in os.environ:
return NotImplementedError("Make sure to set `HUGGINGFACE_HUB_TOKEN` so you can download weights")
else:
install_diffusers()
from diffusers import StableDiffusionPipeline
load_model_checkpoint()

if __name__ == "__main__":
main()
main()

0 comments on commit 623ac4b

Please sign in to comment.