Skip to content

Commit

Permalink
Merge pull request #97 from anyscale/sdxl
Browse files Browse the repository at this point in the history
Upgrade stable diffusion OA template to SDXL
  • Loading branch information
ericl authored Feb 28, 2024
2 parents 89f5fa9 + 6e4fd06 commit 66a4fae
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions templates/serve-stable-diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from io import BytesIO

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response

Expand Down Expand Up @@ -57,19 +57,15 @@ async def generate(self, prompt: str, img_size: int = 512):
"target_num_ongoing_requests_per_replica": 1, # Target number of ongoing requests in a replica. Serve compares the actual number agasint this value and upscales or downscales.
},
)
class StableDiffusionV2:
class StableDiffusionXL:
def __init__(self):
# Load the stable diffusion model inside a Ray Serve Deployment.
model_id = "stabilityai/stable-diffusion-2"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

scheduler = EulerDiscreteScheduler.from_pretrained(
model_id, subfolder="scheduler"
self.pipe = DiffusionPipeline.from_pretrained(
model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
self.pipe = StableDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
)
self.pipe = self.pipe.to("cuda")

self.pipe.to("cuda")

def generate(self, prompt: str, img_size: int = 512):
assert len(prompt), "prompt parameter cannot be empty"
Expand All @@ -81,4 +77,4 @@ def generate(self, prompt: str, img_size: int = 512):

# Bind the deployments to arguments that will be passed into its constructor.
# This defines a Ray Serve application that we can run locally or deploy to production.
stable_diffusion_app = APIIngress.bind(StableDiffusionV2.bind())
stable_diffusion_app = APIIngress.bind(StableDiffusionXL.bind())

0 comments on commit 66a4fae

Please sign in to comment.