-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathllm_node.py
140 lines (123 loc) · 4.94 KB
/
llm_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import base64
from io import BytesIO
import numpy as np
from PIL import Image
from openai import OpenAI
class LLMImageDescription:
    """ComfyUI node that sends an image to an OpenAI-compatible
    chat-completions endpoint and returns the model's text description.

    A single ``OpenAI`` client is cached per (api_key, api_url) pair and
    rebuilt only when either value changes.
    """

    def __init__(self):
        self.output_dir = "output"
        self.type = "output"
        # Lazily-created OpenAI client plus the credentials it was built
        # with, so get_client() can detect when a rebuild is required.
        self._client = None
        self._current_api_key = None
        self._current_api_url = None

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-1.5-pro-latest", "gemini-2.0-flash-exp"],),
                "api_url": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "password": True
                }),
                "api_key": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "password": True
                }),
                "prompt_template": ("STRING", {
                    "default": "Please describe this image in detail:",
                    "multiline": True
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("description",)
    FUNCTION = "process_image"
    CATEGORY = "image/text"

    def get_client(self, api_key, api_url=None):
        """Get or create the cached OpenAI client.

        Rebuilds the client when either the API key or the base URL
        differs from what the cached client was created with.
        """
        if (self._client and
            (self._current_api_key != api_key or
             self._current_api_url != api_url)):
            # Parameters changed: drop the stale client so a fresh one
            # is created below.
            self._client = None
        if not self._client:
            kwargs = {"api_key": api_key}
            if api_url:
                kwargs["base_url"] = api_url
            # Create a new client and remember the parameters it used.
            self._client = OpenAI(**kwargs)
            self._current_api_key = api_key
            self._current_api_url = api_url
        return self._client

    @staticmethod
    def convert_image_to_base64(image):
        """Encode a single image tensor as a base64 PNG string.

        ``image`` is a float tensor with values in [0, 1], either HWC
        (ComfyUI convention) or CHW.
        """
        # Convert PyTorch tensor to a uint8 numpy array. Clip before the
        # cast so values slightly outside [0, 1] can't wrap around.
        image = image.cpu().numpy()
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        if image.shape[0] == 3:  # Heuristic: treat as CHW and transpose.
            # NOTE(review): a genuine HWC image only 3 pixels tall would be
            # misdetected here; ComfyUI normally supplies HWC — confirm.
            image = np.transpose(image, (1, 2, 0))
        pil_image = Image.fromarray(image)
        # Serialize the PIL image to PNG bytes, then base64-encode.
        buffered = BytesIO()
        pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str

    @staticmethod
    def process_with_openai_compatible(prompt_template, base64_image, client, model):
        """Request an image description via the OpenAI chat-completions API.

        Returns the model's reply text, or an "Error: ..." string if the
        request fails (the node reports errors as output rather than raising).
        """
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt_template
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=300
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"API error: {str(e)}")
            return f"Error: API request failed - {str(e)}"

    def process_image(self, image, model, api_url, api_key, prompt_template):
        """Node entry point: describe the (first) input image with the LLM.

        Returns a 1-tuple with the description string; errors are returned
        as text so the graph keeps running.
        """
        try:
            # Only the first image of a batch is described.
            if len(image.shape) == 4:
                image = image[0]
            base64_image = self.convert_image_to_base64(image)
            # Pick a default *base* URL when none was given. The OpenAI SDK
            # appends /chat/completions itself, so these must be base URLs,
            # not full endpoint paths. Anthropic and Gemini are reached via
            # their OpenAI-compatibility endpoints.
            if not api_url:
                api_urls = {
                    "gpt-4o": "https://api.openai.com/v1",
                    "claude-3-5-sonnet-20241022": "https://api.anthropic.com/v1",
                    "gemini-1.5-pro-latest": "https://generativelanguage.googleapis.com/v1beta/openai",
                    "gemini-2.0-flash-exp": "https://generativelanguage.googleapis.com/v1beta/openai"
                }
                api_url = api_urls.get(model, "")
            # All backends speak the OpenAI protocol, so one code path suffices.
            client = self.get_client(api_key, api_url)
            description = self.process_with_openai_compatible(prompt_template, base64_image, client, model)
            return (description,)
        except Exception as e:
            print(f"Error in image description: {str(e)}")
            return (f"Error: Failed to generate image description. {str(e)}",)

    def __del__(self):
        """Best-effort cleanup of the cached client."""
        # getattr: __del__ may run after a failed __init__ or during
        # interpreter shutdown, when attributes can be missing.
        client = getattr(self, "_client", None)
        if client is not None:
            try:
                client.close()
            except Exception:
                pass  # never raise from a destructor