-
Notifications
You must be signed in to change notification settings - Fork 0
/
computations.py
91 lines (73 loc) · 2.47 KB
/
computations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import base64
import os

import requests
_STABILITY_API_HOST = "https://api.stability.ai"


def _numbered_path(output_img_path, index):
    """Return *output_img_path* with ``_<index>`` inserted before its extension.

    Examples: ``"out.png"`` -> ``"out_0.png"``, ``"a.b.png"`` -> ``"a.b_0.png"``.
    A path without an extension simply gets ``_<index>`` appended.
    """
    root, ext = os.path.splitext(output_img_path)
    return f"{root}_{index}{ext}"


def _generate_v1(api_key, prompt, output_img_path, engine_id, timeout):
    """Call the v1 JSON text-to-image endpoint and save each returned artifact.

    Each artifact is base64-decoded and written to ``_numbered_path(output_img_path, i)``.
    Returns the last path written, or ``""`` if the response held no artifacts.
    """
    response = requests.post(
        f"{_STABILITY_API_HOST}/v1/generation/{engine_id}/text-to-image",
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={
            "text_prompts": [{"text": prompt}],
            "cfg_scale": 7,
            "height": 512,
            "width": 512,
            "samples": 1,
            "steps": 30,
        },
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError("Non-200 response: " + str(response.text))
    data = response.json()
    path = ""
    for i, image in enumerate(data["artifacts"]):
        path = _numbered_path(output_img_path, i)
        with open(path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))
    return path


def _generate_v3(api_key, prompt, output_img_path, timeout):
    """Call the v2beta SD3 endpoint and write the returned PNG bytes to *output_img_path*."""
    response = requests.post(
        f"{_STABILITY_API_HOST}/v2beta/stable-image/generate/sd3",
        headers={
            "authorization": f"Bearer {api_key}",
            "accept": "image/*",
        },
        # Passing any ``files`` entry makes requests encode the body as
        # multipart/form-data; this empty dummy part exists only for that.
        files={"none": ""},
        data={
            "prompt": prompt,
            "output_format": "png",
        },
        timeout=timeout,
    )
    if response.status_code != 200:
        # Prefer the JSON error payload, but don't let a non-JSON body (e.g. a
        # proxy error page) hide the failure behind a decode error.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise RuntimeError(str(detail))
    with open(output_img_path, "wb") as file:
        file.write(response.content)
    return output_img_path


def generate_img(api_key, prompt, output_img_path, *,
                 engine_id="stable-diffusion-v3", timeout=120):
    """Generate an image for *prompt* with the Stability AI API and save it.

    Args:
        api_key: Stability AI API key, sent as a Bearer token.
        prompt: text prompt describing the image.
        output_img_path: destination file path. For the v1-6 engine an
            ``_<i>`` index is inserted before the extension for each sample.
        engine_id: ``"stable-diffusion-v3"`` (default; v2beta SD3 endpoint) or
            ``"stable-diffusion-v1-6"`` (v1 text-to-image endpoint).
        timeout: seconds passed to ``requests.post`` so a stalled connection
            cannot block forever; ``None`` restores the old unbounded wait.

    Returns:
        The path of the written image file.

    Raises:
        ValueError: if *engine_id* is not one of the supported engines.
        RuntimeError: if the API answers with a non-200 status.
    """
    if engine_id == "stable-diffusion-v1-6":
        return _generate_v1(api_key, prompt, output_img_path, engine_id, timeout)
    if engine_id == "stable-diffusion-v3":
        return _generate_v3(api_key, prompt, output_img_path, timeout)
    raise ValueError(f"Invalid engine_id: {engine_id}")
def compute(api_key, prompt):
    """
    Generates an image from an input prompt using the Stability AI API.
    Input:
        api_key: your Stability AI API key.
        prompt: the input prompt.
    Output:
        A dict {"output_img": path}, where path is the image file path
        returned by generate_img for the requested file
        "sd_generated_image.png" (a relative path, so it is resolved against
        the current working directory).
    """
    # Plain literal: the previous f-string had no placeholders (ruff F541).
    output_img_path = "sd_generated_image.png"
    output_img = generate_img(api_key, prompt, output_img_path)
    return {"output_img": output_img}
def test():
    """Smoke-test entry point; at present it only announces that it ran."""
    banner = "Running test"
    print(banner)