
Commit fd9164f

Fixed streaming + web agent with proper pydantic

winternewt committed Jan 9, 2025
Parent: fb21c31

Showing 9 changed files with 320 additions and 2,909 deletions.
2 changes: 1 addition & 1 deletion core/just_agents/interfaces/streaming_protocol.py
```diff
@@ -55,7 +55,7 @@ def sse_wrap(data: Union[Dict[str, Any], str], event: Optional[str] = None) -> str:
 
         # Append a blank line to separate events
         lines.append("")
-        return "\n".join(lines)
+        return "\n".join(lines) + "\n"
 
     @staticmethod
     def sse_parse(sse_text: str) -> Dict[str, Any]:
```
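The trailing newline is the substance of this fix: in Server-Sent Events, every event must be terminated by a blank line, so the final `data:` line has to be followed by `\n\n` on the wire. A minimal illustration of the framing (the `sse_events` helper below is hypothetical, written only for this sketch):

```python
# Minimal illustration of SSE event framing. sse_events is a hypothetical
# helper, not part of just_agents: events are separated by a blank line,
# so every wrapped event must end with "\n\n".

def sse_events(raw: str):
    """Split a raw SSE stream into events at blank-line boundaries."""
    return [block for block in raw.split("\n\n") if block.strip()]

# With the "+ \n" fix, two wrapped events concatenate cleanly:
stream = 'data: {"i": 0}\n\n' + 'data: {"i": 1}\n\n'
assert sse_events(stream) == ['data: {"i": 0}', 'data: {"i": 1}']

# Without it, the separator blank line is lost and a spec-compliant
# client keeps buffering everything as one half-finished event:
broken = 'data: {"i": 0}\n' + 'data: {"i": 1}\n'
assert sse_events(broken) == ['data: {"i": 0}\ndata: {"i": 1}\n']
```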
2 changes: 1 addition & 1 deletion core/just_agents/llm_options.py
```diff
@@ -6,7 +6,7 @@
 class ModelOptions(BaseModel):
     model: str = Field(
         ...,
-        examples=["gpt-4o-mini"],
+        examples=["groq/llama-3.3-70b-versatile","gpt-4o-mini"],
         description="LLM model name"
     )
     temperature: Optional[float] = Field(
```
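The `examples` list is what pydantic and FastAPI surface as the sample payload in the generated OpenAPI docs; the new first entry advertises a litellm-style provider-prefixed model name. A reduced sketch, assuming only the fields visible in this hunk:

```python
from typing import Optional
from pydantic import BaseModel, Field

# Reduced sketch of ModelOptions, assuming only the fields visible in the hunk.
class ModelOptions(BaseModel):
    model: str = Field(..., examples=["groq/llama-3.3-70b-versatile", "gpt-4o-mini"],
                       description="LLM model name")
    temperature: Optional[float] = None

opts = ModelOptions(model="groq/llama-3.3-70b-versatile", temperature=0.0)
print(opts.model_dump())  # {'model': 'groq/llama-3.3-70b-versatile', 'temperature': 0.0}
```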
6 changes: 4 additions & 2 deletions core/just_agents/protocols/openai_streaming.py
```diff
@@ -1,5 +1,5 @@
 from just_agents.interfaces.streaming_protocol import IAbstractStreamingProtocol
-
+import json
 import time
 
 class OpenaiStreamingProtocol(IAbstractStreamingProtocol):
@@ -15,7 +15,9 @@ def get_chunk(self, index: int, delta: str, options: dict):
             "choices": [{"delta": {"content": delta}}],
         }
         return self.sse_wrap(chunk)
+        # return f"data: {json.dumps(chunk)}\n\n"
 
 
     def done(self):
-        return self.sse_wrap(self.stop)
+        return self.sse_wrap(self.stop)
+        # return "\ndata: [DONE]\n\n"
```
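Taken together with the `sse_wrap` fix above, each streamed chunk now reaches the client as a fully terminated SSE event. A sketch of the resulting wire format — only `"choices"` is visible in this hunk, so the other chunk fields here are assumptions based on the OpenAI chunk schema:

```python
import json
import time

# Sketch of the wire format get_chunk now produces via sse_wrap.
# Fields other than "choices" are assumptions, not taken from the hunk.
def chunk_wire(delta: str, model: str) -> str:
    chunk = {
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"delta": {"content": delta}}],
    }
    return f"data: {json.dumps(chunk)}\n\n"  # note the terminating blank line

print(chunk_wire("Hello", "gpt-4o-mini"), end="")
# data: {"object": "chat.completion.chunk", "created": ..., "choices": [{"delta": {"content": "Hello"}}]}
```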
164 changes: 82 additions & 82 deletions poetry.lock

(Large diff not rendered.)

2 changes: 1 addition & 1 deletion web/README.md
````diff
@@ -24,7 +24,7 @@ pip install just-agents-web
 ## Quick Start
 
 ```python
-from just_agents_web import create_app
+from just_agents.web import create_app
 from just_agents.simple.chat_agent import ChatAgent
 
 agent = ChatAgent(...)
````
61 changes: 61 additions & 0 deletions web/just_agents/web/models.py
@@ -0,0 +1,61 @@ (new file, shown as plain Python since every line is an addition)

```python
from typing import Type, TypeVar, Any, List, Union, Optional, Literal, AsyncGenerator, cast
from pydantic import BaseModel, Field, HttpUrl

from just_agents.protocols.litellm_protocol import Message, TextContent
from just_agents.llm_options import ModelOptions

from openai.types import CompletionUsage
from openai.types.chat.chat_completion import ChatCompletion, Choice, ChatCompletionMessage

class ModelOptionsExt(ModelOptions):
    api_key: str = Field(None, examples=["openai_api_key"])

class ChatCompletionRequest(ModelOptions):
    messages: List[Message] = Field(..., examples=[[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What drug interactions of rapamycin are you aware of? What are these interactions ?"}
    ]])
    n: Optional[int] = Field(1, ge=1)
    stream: Optional[bool] = Field(default=False, examples=[True])
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(None, ge=1)
    logit_bias: Optional[dict] = Field(None, examples=[None])
    user: Optional[str] = Field(None, examples=[None])

class ResponseMessage(ChatCompletionMessage):
    role: Optional[Literal["system", "user", "assistant", "tool"]] = None

class ChatCompletionChoice(Choice):
    finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] = None
    # text: Optional[str] = Field(default=None, alias="message.content")
    message: Optional[ResponseMessage]

class ChatCompletionChoiceChunk(ChatCompletionChoice):
    delta: ResponseMessage = Field(default=None)
    message: Optional[ResponseMessage] = Field(default=None, exclude=True)  # hax

class ChatCompletionUsage(CompletionUsage):
    prompt_tokens: int = Field(default=0)
    completion_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)

class ChatCompletionResponse(ChatCompletion):
    # id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    created: Union[int, float]
    # model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[ChatCompletionUsage] = Field(default=None)


class ChatCompletionChunkResponse(ChatCompletionResponse):
    choices: List[ChatCompletionChoiceChunk]

class ErrorResponse(BaseModel):
    class ErrorDetails(BaseModel):
        message: str = Field(...)
        type: str = Field("server_error")
        code: str = Field("internal_server_error")

    error: ErrorDetails
```
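A quick sketch of these schemas at the API boundary. It assumes the imported `Message` model (from `litellm_protocol`, not shown on this page) validates plain role/content dicts, as the `Field` examples above suggest:

```python
# Assumes Message accepts plain role/content dicts, per the examples above.
from just_agents.web.models import ChatCompletionRequest, ErrorResponse

payload = {
    "model": "gpt-4o-mini",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
}
request = ChatCompletionRequest.model_validate(payload)  # raises ValidationError on bad input
assert request.n == 1 and request.stream is False        # defaults kick in

err = ErrorResponse(error=ErrorResponse.ErrorDetails(message="boom"))
print(err.model_dump_json())
# {"error":{"message":"boom","type":"server_error","code":"internal_server_error"}}
```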
206 changes: 109 additions & 97 deletions web/just_agents/web/rest_api.py
```diff
@@ -1,18 +1,24 @@
+import os
+import json
+import time
+import asyncio
+
 from pathlib import Path
-from typing import Optional, List, Dict, Any, Union
-from fastapi import FastAPI
+from typing import Optional, List, Dict, Any, Union, AsyncGenerator
 
 from just_agents.base_agent import BaseAgent
 from just_agents.interfaces.agent import IAgent
-from starlette.responses import StreamingResponse
+from just_agents.web.streaming import async_wrap
+from just_agents.web.models import (
+    ChatCompletionRequest, TextContent, ChatCompletionChoiceChunk, ChatCompletionChunkResponse,
+    ChatCompletionResponse, ChatCompletionChoice, ChatCompletionUsage, ResponseMessage, ErrorResponse
+)
 from dotenv import load_dotenv
+from just_agents.interfaces.streaming_protocol import IAbstractStreamingProtocol
 from fastapi.middleware.cors import CORSMiddleware
-from pycomfort.logging import log_function
 import yaml
-import os
+from pycomfort.logging import log_function
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
 from eliot import log_call, log_message
-import json
 
 
 
 class AgentRestAPI(FastAPI):
```
```diff
@@ -91,72 +97,75 @@ def _routes_config(self):
         )
         # Register routes
         self.get("/")(self.default)
-        self.post("/v1/chat/completions")(self.chat_completions)
-
-
+        self.post("/v1/chat/completions", description="OpenAI compatible chat completions")(self.chat_completions)
 
-    def _clean_messages(self, request: dict):
-        for message in request["messages"]:
-            if message["role"] == "user":
-                content = message["content"]
+    def _clean_messages(self, request: ChatCompletionRequest):
+        for message in request.messages:
+            if message.role == "user":
+                content = message.content
                 if type(content) is list:
                     if len(content) > 0:
-                        if type(content[0]) is dict:
-                            if content[0].get("type", "") == "text":
-                                if type(content[0].get("text", None)) is str:
-                                    message["content"] = content[0]["text"]
-
-    def _remove_system_prompt(self, request: dict):
-        if request["messages"][0]["role"] == "system":
-            request["messages"] = request["messages"][1:]
+                        if isinstance(content[0], TextContent):
+                            message.content = content[0].text
+
+    def _remove_system_prompt(self, request: ChatCompletionRequest):
+        if request.messages[0].role == "system":
+            request.messages = request.messages[1:]
 
     def default(self):
        return f"This is default page for the {self.title}"
 
-    @log_call(action_type="chat_completions", include_result=False)
-    def chat_completions(self, request: dict):
+
+    # @log_call(action_type="chat_completions", include_result=False)
+    async def chat_completions(self, request: ChatCompletionRequest) -> Union[ChatCompletionResponse, Any, ErrorResponse]:
         try:
             agent = self.agent
             self._clean_messages(request)
             self._remove_system_prompt(request)
 
-            if not request["messages"]:
-                log_message(
-                    message_type="validation_error",
-                    error="No messages provided in request"
-                )
-                return {
-                    "error": {
-                        "message": "No messages provided in request",
-                        "type": "invalid_request_error",
-                        "param": "messages",
-                        "code": "invalid_request_error"
-                    }
-                }, 400
-
-            # Validate required fields
-            if "model" not in request:
-                log_message(
-                    message_type="validation_error",
-                    error="model is required"
-                )
-                return {
-                    "error": {
-                        "message": "model is required",
-                        "type": "invalid_request_error",
-                        "param": "model",
-                        "code": "invalid_request_error"
-                    }
-                }, 400
-
-            is_streaming = request.get("stream", False)
-            stream_generator = agent.stream(request["messages"])
+            #Done by FastAPI+pydantic under the hood! Just supply schema...
+
+            # if not request.messages:
+            #     log_message(
+            #         message_type="validation_error",
+            #         error="No messages provided in request"
+            #     )
+            #     return {
+            #         "error": {
+            #             "message": "No messages provided in request",
+            #             "type": "invalid_request_error",
+            #             "param": "messages",
+            #             "code": "invalid_request_error"
+            #         }
+            #     }, 400
+            #
+            # # Validate required fields
+            # if "model" not in request:
+            #     log_message(
+            #         message_type="validation_error",
+            #         error="model is required"
+            #     )
+            #     return {
+            #         "error": {
+            #             "message": "model is required",
+            #             "type": "invalid_request_error",
+            #             "param": "model",
+            #             "code": "invalid_request_error"
+            #         }
+            #     }, 400
 
+            is_streaming = request.stream
+            messages = [message.model_dump(mode='json') for message in request.messages]  # todo: support pydantic model!!!
+            stream_generator = agent.stream(
+                messages
+            )
 
             if is_streaming:
                 return StreamingResponse(
-                    stream_generator,
-                    media_type="text/event-stream",
+                    async_wrap(stream_generator),
+                    media_type="application/x-ndjson",
                     headers={
                         "Cache-Control": "no-cache",
                         "Connection": "keep-alive",
```
```diff
@@ -167,53 +176,56 @@ def chat_completions(self, request: dict):
             else:
                 # Collect all chunks into final response
                 response_content = ""
                 for chunk in stream_generator:
+
                     if chunk == "[DONE]":
                         break
                     try:
-                        # Parse the SSE data
-                        data = json.loads(chunk.decode().split("data: ")[1])
-                        if "choices" in data and len(data["choices"]) > 0:
-                            delta = data["choices"][0].get("delta", {})
+                        data = IAbstractStreamingProtocol.sse_parse(chunk)
+                        json_data = data.get("data", "{}")
+                        print(json_data)
+                        if "choices" in json_data and len(json_data["choices"]) > 0:
+                            delta = json_data["choices"][0].get("delta", {})
                             if "content" in delta:
                                 response_content += delta["content"]
                     except Exception:
                         continue
 
-                return {
-                    "id": f"chatcmpl-{time.time()}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": request.get("model", "unknown"),
-                    "choices": [{
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": response_content
-                        },
-                        "finish_reason": "stop"
-                    }],
-                    "usage": {
-                        "prompt_tokens": 0,
-                        "completion_tokens": 0,
-                        "total_tokens": 0
-                    }
-                }
+                return ChatCompletionResponse(
+                    id=f"chatcmpl-{time.time()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=request.model,
+                    choices=[ChatCompletionChoice(
+                        index=0,
+                        message=ResponseMessage(
+                            role="assistant",
+                            content=response_content
+                        ),
+                        finish_reason="stop"
+                    )],
+                    usage=ChatCompletionUsage(
+                        prompt_tokens=0,
+                        completion_tokens=0,
+                        total_tokens=0
+                    ))
 
         except Exception as e:
-            log_message(
-                message_type="chat_completion_error",
-                error=str(e),
-                error_type=type(e).__name__,
-                request_details={
-                    "model": request.get("model"),
-                    "message_count": len(request.get("messages", [])),
-                    "streaming": request.get("stream", False)
-                }
-            )
-            return {
-                "error": {
-                    "message": str(e),
-                    "type": "server_error",
-                    "code": "internal_server_error"
-                }
-            }, 500
+            # log_message(
+            #     message_type="chat_completion_error",
+            #     error=str(e),
+            #     error_type=type(e).__name__,
+            #     request_details={
+            #         "model": request.model,
+            #         "message_count": len(request.messages),
+            #         "streaming": request.stream
+            #     }
+            # )
+
+            error_response = ErrorResponse(
+                error=ErrorResponse.ErrorDetails(
+                    message=str(e)
+                )
+            )
+            return error_response
```

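With the request and response now typed, the endpoint should be drop-in compatible with standard OpenAI clients. A usage sketch — the base URL and port are assumptions about a local deployment, not taken from this commit:

```python
# Usage sketch: base URL/port are assumptions about local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8088/v1", api_key="not-needed")
stream = client.chat.completions.create(
    model="groq/llama-3.3-70b-versatile",
    messages=[{"role": "user", "content": "What drug interactions of rapamycin are you aware of?"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```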
(The remaining changed files are not rendered here.)