Skip to content

Commit

Permalink
feat: add image support to groq (#827)
Browse files Browse the repository at this point in the history
  • Loading branch information
AarushSah authored Nov 10, 2024
1 parent da1b70a commit 2f2891e
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/inspect_ai/model/_providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from groq.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
Expand All @@ -18,6 +21,9 @@
from typing_extensions import override

from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
from inspect_ai._util.content import Content
from inspect_ai._util.images import image_as_data_uri
from inspect_ai._util.url import is_data_uri, is_http_url
from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

from .._chat_message import (
Expand Down Expand Up @@ -204,7 +210,15 @@ async def groq_chat_message(message: ChatMessage) -> ChatCompletionMessageParam:
return ChatCompletionSystemMessageParam(role="system", content=message.text)

elif isinstance(message, ChatMessageUser):
return ChatCompletionUserMessageParam(role="user", content=message.text)
content = (
message.content
if isinstance(message.content, str)
else [
await as_chat_completion_part(content)
for content in message.content
]
)
return ChatCompletionUserMessageParam(role="user", content=content)

elif isinstance(message, ChatMessageAssistant):
return ChatCompletionAssistantMessageParam(
Expand All @@ -230,6 +244,28 @@ async def groq_chat_message(message: ChatMessage) -> ChatCompletionMessageParam:
)


async def as_chat_completion_part(
content: Content,
) -> ChatCompletionContentPartParam:
if content.type == "text":
return ChatCompletionContentPartTextParam(type="text", text=content.text)
else:
# API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
image_url, detail = (
(content.image, "auto")
if isinstance(content.image, str)
else (content.image, content.detail)
)

if not is_http_url(image_url) and not is_data_uri(image_url):
image_url = await image_as_data_uri(image_url)

return ChatCompletionContentPartImageParam(
type="image_url",
image_url=dict(url=image_url, detail=detail),
)


def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:
return [
{"type": "function", "function": tool.model_dump(exclude_none=True)}
Expand Down

0 comments on commit 2f2891e

Please sign in to comment.