From 6e4fd06b87fcf2794613af01359268240b156e4b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 27 Feb 2024 16:39:25 -0800 Subject: [PATCH] upgade to xdl --- templates/serve-stable-diffusion/main.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/templates/serve-stable-diffusion/main.py b/templates/serve-stable-diffusion/main.py index d503fdd8f..67fa8a2f0 100644 --- a/templates/serve-stable-diffusion/main.py +++ b/templates/serve-stable-diffusion/main.py @@ -2,7 +2,7 @@ from io import BytesIO import torch -from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline +from diffusers import DiffusionPipeline from fastapi import FastAPI, HTTPException from fastapi.responses import Response @@ -57,19 +57,15 @@ async def generate(self, prompt: str, img_size: int = 512): "target_num_ongoing_requests_per_replica": 1, # Target number of ongoing requests in a replica. Serve compares the actual number agasint this value and upscales or downscales. }, ) -class StableDiffusionV2: +class StableDiffusionXL: def __init__(self): # Load the stable diffusion model inside a Ray Serve Deployment. - model_id = "stabilityai/stable-diffusion-2" + model_id = "stabilityai/stable-diffusion-xl-base-1.0" - scheduler = EulerDiscreteScheduler.from_pretrained( - model_id, subfolder="scheduler" + self.pipe = DiffusionPipeline.from_pretrained( + model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16" ) - self.pipe = StableDiffusionPipeline.from_pretrained( - model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16 - ) - self.pipe = self.pipe.to("cuda") - + self.pipe.to("cuda") def generate(self, prompt: str, img_size: int = 512): assert len(prompt), "prompt parameter cannot be empty" @@ -81,4 +77,4 @@ def generate(self, prompt: str, img_size: int = 512): # Bind the deployments to arguments that will be passed into its constructor. # This defines a Ray Serve application that we can run locally or deploy to production. -stable_diffusion_app = APIIngress.bind(StableDiffusionV2.bind()) +stable_diffusion_app = APIIngress.bind(StableDiffusionXL.bind())