-
Notifications
You must be signed in to change notification settings - Fork 0
/
anime.py
138 lines (118 loc) · 5.2 KB
/
anime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import modal
import modal.gpu
import logging

# Seconds per minute — used to express Modal timeouts readably.
ONE_MINUTE = 60

# Initialize the modal app
app = modal.App(name="animegen")

# Setup logging configuration
logging.basicConfig(level=logging.INFO)

# Modal image with dependencies.
# Pinned versions keep the container build reproducible; unpinned packages
# (peft, bitsandbytes, slowapi, starlette, requests) float to latest.
image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "peft",
    "accelerate==0.33.0",
    "diffusers==0.31.0",
    "fastapi==0.115.4",
    "huggingface-hub[hf_transfer]==0.25.2",
    "sentencepiece==0.2.0",
    "torch==2.5.1",
    "transformers~=4.44.0",
    "bitsandbytes",
    "slowapi",
    "starlette",
    "requests",
).env({
    "HF_HUB_ENABLE_HF_TRANSFER": "1",  # Faster downloads of Hugging Face models
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"  # Avoid memory segmentation
})

import os
import io
from starlette.requests import Request

MODEL_DIR = "/model"  # Path inside the volume
# Persistent volume caching the downloaded model weights across builds/runs.
volume = modal.Volume.from_name("animagine-xl-3.1", create_if_missing=True)
model_id = "cagliostrolab/animagine-xl-3.1"
@app.cls(image=image, gpu="A10G", timeout=8 * ONE_MINUTE, secrets=[modal.Secret.from_name("huggingface-secret"), modal.Secret.from_name("API_KEY")], volumes={MODEL_DIR: volume}, container_idle_timeout=180)
class AnimeGen:
    """Modal service class for the animagine-xl-3.1 diffusion pipeline.

    Downloads the model to a persistent volume at build time, loads it onto
    an A10G GPU when a container starts, and exposes authenticated web
    endpoints for image generation and health checks.
    """

    @modal.build()
    def download_model(self):
        """Downloads the model and saves it to the Modal Volume during build."""
        import torch
        from diffusers import DiffusionPipeline

        # model_index.json is the pipeline manifest; its presence means a
        # previous build already populated the volume.
        model_path = os.path.join(MODEL_DIR, "model_index.json")
        if os.path.exists(model_path):
            logging.info(" Skip download --> Model is already present on volume 👍")
        else:
            pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
            pipeline.save_pretrained(MODEL_DIR)
            logging.info(" Model downloaded and saved to the volume. 🥳")

    @modal.enter()
    def initialize(self):
        """Loads the model from the volume into GPU memory at runtime."""
        import torch
        from diffusers import DiffusionPipeline

        self.API_KEY = os.environ["API_KEY"]
        logging.info("Initializing DiffusionPipeline and loading model to GPU 🚀...")
        # Mount the volume and load model
        self.pipe = DiffusionPipeline.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.float16,
        )
        # Trade some latency for VRAM headroom on the A10G.
        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_attention_slicing()
        if hasattr(self.pipe.tokenizer, 'model_max_length'):
            logging.info(f"Tokenizer max length: {self.pipe.tokenizer.model_max_length}")
        logging.info(" Model successfully loaded into GPU memory. 👍")

    @modal.method()
    def run(self, prompt: str) -> bytes:
        """Generates an image based on the given prompt.

        Returns:
            The rendered image as raw PNG bytes.
            (Annotation fixed: the original declared ``list[bytes]`` but the
            method has always returned a single ``bytes`` object.)
        """
        logging.info(f"Generating image with prompt: {prompt}")
        tokens = self.pipe.tokenizer.encode(prompt)
        logging.info(f"Token count: {len(tokens)}")
        logging.info(f"Tokenized text: {self.pipe.tokenizer.convert_ids_to_tokens(tokens)}")
        image = self.pipe(
            prompt,
            negative_prompt="nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
            num_inference_steps=50,
        ).images[0]
        buffer = io.BytesIO()
        # PNG is lossless, so PIL ignores `quality` — dropped. compress_level=0
        # skips zlib compression for the fastest possible save.
        image.save(buffer, format="PNG", optimize=False, compress_level=0)
        return buffer.getvalue()

    @modal.web_endpoint(docs=True)
    def generate(self, prompt: str, request: Request):
        """HTTP endpoint: validates the X-API-KEY header, then returns the
        generated image for `prompt` as a PNG response."""
        import hmac
        from starlette.responses import Response
        from urllib.parse import unquote

        api_key = request.headers.get("X-API-KEY")
        # Validate the API key. compare_digest is constant-time, so the check
        # does not leak key contents through response-timing differences.
        if api_key is None or not hmac.compare_digest(api_key, self.API_KEY):
            return Response("Unauthorized attempt to access the endpoint", status_code=401)
        # Generate the image
        decoded_prompt = unquote(prompt)
        image_bytes = self.run.local(decoded_prompt)
        # Enhanced headers to prevent any transformation
        headers = {
            "Content-Type": "image/png",
            "Cache-Control": "no-transform, no-cache, must-revalidate, proxy-revalidate, max-age=0",
            "Accept-Ranges": "none",
            "Vary": "Accept-Encoding",
            "Pragma": "no-cache",
            "Expires": "0",
            "Cross-Origin-Resource-Policy": "cross-origin",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET",
        }
        return Response(
            content=image_bytes,
            status_code=200,
            media_type="image/png",
            headers=headers
        )

    @modal.web_endpoint(docs=True)
    def health(self):
        """Keeps the container warm."""
        # NOTE: docstring moved above the import — as written originally it
        # followed a statement and was a no-op string, not a real docstring.
        from datetime import datetime, timezone

        return {"status": "Healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
@app.function(schedule=modal.Cron("0 */4 * * *"), secrets=[modal.Secret.from_name("API_KEY"), modal.Secret.from_name("HEALTH")], image=image)  # run every 4 hours
def update_keep_warm():
    """Pings the health endpoint on a 4-hour cron so containers stay warm.

    Reads the endpoint URL from the HEALTH secret and logs the timestamp the
    endpoint reports. Raises on HTTP errors so failures show up in Modal logs.
    """
    import requests

    health_url = os.environ["HEALTH"]
    # Timeout keeps the scheduled job from hanging forever if the endpoint
    # is unreachable (requests has no default timeout).
    health_response = requests.get(health_url, timeout=30)
    # Surface failed health checks instead of silently printing garbage.
    health_response.raise_for_status()
    logging.info(f"Health check at: {health_response.json()['timestamp']}")