-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathapp.py
75 lines (66 loc) · 2.85 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from operator import mod
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import base64
from io import BytesIO
import os
import PIL
from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
import os
import random
# DPM-Solver++ multistep scheduler using a "scaled_linear" beta schedule
# (beta 0.00085 -> 0.012) over 1000 training timesteps.
# NOTE(review): this object is not referenced anywhere else in this file;
# init() does not pass it to from_pretrained(), so it currently has no effect
# on inference. Confirm whether it was meant to be wired in.
scheduler = DPMSolverMultistepScheduler(
    # Noise schedule used during training.
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    trained_betas=None,
    # Solver settings.
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
    thresholding=False,
)
# init() runs once on server startup.
# It loads the model onto the GPU and stores it in the global variable "model"
# so that inference() can reuse it on every request.
def init():
    """Load the Stable Diffusion inpainting pipeline onto the GPU.

    The pipeline is stored in the module-level global ``model``.

    The Hugging Face token is taken from the ``HF_AUTH_TOKEN`` environment
    variable. When that variable is unset or empty, the in-file placeholder is
    used instead. Prefer setting the environment variable over editing the
    placeholder, so that real tokens never have to be committed.
    """
    global model
    # Reading the secret from the environment keeps it out of source control.
    hf_auth_token = os.environ.get("HF_AUTH_TOKEN") or "ADD YOUR AUTH TOKEN HERE"
    # NOTE(review): the module-level `scheduler` is not passed here, so the
    # pipeline presumably uses the scheduler from the model repo's config.
    # If DPM-Solver++ was intended, add `scheduler=scheduler` to the call below.
    # Doing so changes the images generated for a given seed.
    model = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=hf_auth_token,
    ).to("cuda")
    # Compute attention in slices to reduce peak GPU memory.
    # Per the diffusers docs, this costs some speed.
    model.enable_attention_slicing()
def _decode_base64_image(data, mode):
    """Decode a base64-encoded image string into a PIL image converted to *mode*.

    *data* may be either a bare base64 payload or a data URL of the form
    ``data:<mime>;base64,<payload>``.

    Raises ValueError (including binascii.Error) for undecodable base64, and
    OSError (including PIL.UnidentifiedImageError) for bytes that Pillow
    cannot read as an image.
    """
    if data.startswith("data:") and "," in data:
        # Strip the data-URL header. ':' and ',' never appear in a base64
        # payload, so bare payloads are unaffected.
        data = data.split(",", 1)[1]
    image_bytes = BytesIO(base64.b64decode(data.encode('utf-8')))
    # Image.open is lazy. convert() forces the pixel data to be decoded here,
    # so a corrupt image fails now rather than deep inside the pipeline.
    return PIL.Image.open(image_bytes).convert(mode)


# inference() runs for every server call.
# It uses the global "model" that init() preloaded.
def inference(model_inputs:dict) -> dict:
    """Inpaint the masked region of an image with the preloaded global ``model``.

    Keys read from *model_inputs*:
      - ``prompt`` (required): text prompt.
      - ``init_image_base64`` (required): base64 (or data-URL) source image.
      - ``mask_image_base64`` (required): base64 (or data-URL) mask image.
      - ``guidance_scale`` (default 7.5), ``height`` (default 512),
        ``width`` (default 512), ``steps`` (default 50), ``seed`` (default 0).

    Returns ``{'image_base64': <PNG encoded as base64 str>}`` on success.
    Returns ``{'message': <reason>}`` when a required input is missing or an
    image cannot be decoded.
    """
    global model
    # Validate the required inputs first; bail out with a message if absent.
    prompt = model_inputs.get('prompt', None)
    if prompt is None:
        return {'message': "No prompt provided"}
    init_image_base64 = model_inputs.get('init_image_base64', None)
    if init_image_base64 is None:
        return {'message': "No init_image provided"}
    mask_image_base64 = model_inputs.get('mask_image_base64', None)
    if mask_image_base64 is None:
        return {'message': "No mask_image provided"}

    # Optional generation parameters.
    guidance_scale = model_inputs.get("guidance_scale", 7.5)
    height = model_inputs.get("height", 512)
    width = model_inputs.get("width", 512)
    steps = model_inputs.get("steps", 50)
    seed = model_inputs.get("seed", 0)

    # Report bad image payloads the same way as missing inputs,
    # instead of letting the exception escape the handler.
    try:
        init_image = _decode_base64_image(init_image_base64, "RGB")
    except (ValueError, OSError):
        return {'message': "Invalid init_image provided"}
    try:
        mask_image = _decode_base64_image(mask_image_base64, "L")
    except (ValueError, OSError):
        return {'message': "Invalid mask_image provided"}

    # A seeded generator makes the output reproducible for a given seed.
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Run the model.
    with autocast("cuda"):
        image = model(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]

    # Encode the result as a base64 PNG string for the response.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return {'image_base64': image_base64}