Skip to content

Commit

Permalink
chore: Various code cleanup. Logo does not display from local URL.
Browse files Browse the repository at this point in the history
  • Loading branch information
anirbanbasu committed Aug 17, 2024
1 parent cfdc348 commit 1c3dc31
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 65 deletions.
68 changes: 31 additions & 37 deletions src/coder_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,13 @@ class CoderOutput(BaseModel):
code: str = Field(..., description="Python code implementation for the solution.")


class MultiAgentOrchestrator:
def __init__(self, llm: BaseChatModel, prompt: ChatPromptTemplate):
class MultiAgentDirectedGraph:
def __init__(self, llm: BaseChatModel, solver_prompt: ChatPromptTemplate):
self._llm = llm
draft_solver_messages = [
("system", constants.ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT),
("human", "{input}"),
]
self._draft_solver_prompt = ChatPromptTemplate.from_messages(
messages=draft_solver_messages
)
self._runnable_draft_solver = (
self._draft_solver_prompt | self._llm.with_structured_output(CoderOutput)
self._solver_agent = solver_prompt | self._llm.with_structured_output(
CoderOutput
)
# solver_messages = [
# ("system", constants.ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT),
# ("human", "{input}"),
# ]
# self._solver_prompt = ChatPromptTemplate.from_messages(messages=solver_messages)
self._runnable_solver = prompt | self._llm.with_structured_output(CoderOutput)
# This is NOT an agent and it should not be.
self._evaluator = CodeExecutor()
# self._retriever = BM25Retriever.from_texts(
# [self.format_example(row) for row in train_ds]
Expand Down Expand Up @@ -115,7 +103,7 @@ def format_test_cases(self, test_cases: list[TestCase]) -> str:
"""
test_cases_str = "\n".join(
[
f"<test id={i}>\n{test_case}\n</test>"
f"<test id='{i}'>\n{test_case}\n</test>"
for i, test_case in enumerate(test_cases)
]
)
Expand Down Expand Up @@ -180,24 +168,18 @@ def solve(self, state: AgentState) -> dict:
else constants.AGENT_STATE__KEY_MESSAGES
)
if has_examples:
# output_key = constants.AGENT_STATE__KEY_MESSAGES
# Retrieve examples to solve the problem
inputs[constants.AGENT_STATE__KEY_EXAMPLES] = state[
constants.AGENT_STATE__KEY_EXAMPLES
]
else:
inputs[constants.AGENT_STATE__KEY_EXAMPLES] = constants.EMPTY_STRING
response = self.pydantic_to_ai_message(
# Use the draft solver only if the `draft` flag is set in the state
self._runnable_draft_solver.invoke(inputs)
if state[constants.AGENT_STATE__KEY_DRAFT] is True
else self._runnable_solver.invoke(inputs)
)
response = self.pydantic_to_ai_message(self._solver_agent.invoke(inputs))
ic(response)
return (
{
output_key: [response],
constants.AGENT_STATE__KEY_DRAFT: (False),
constants.AGENT_STATE__KEY_DRAFT: False,
}
if state[constants.AGENT_STATE__KEY_DRAFT]
else {output_key: [response]}
Expand Down Expand Up @@ -244,7 +226,7 @@ def evaluate(self, state: AgentState) -> dict:
"""
ic(f"State in evaluate: {state}")
test_cases = state[constants.AGENT_STATE__KEY_TEST_CASES]
# Extract the `AIMessage` that is expected to contain the code from the last call.
# Extract the `AIMessage` that is expected to contain the code from the last call to the solver that was NOT to generate a candidate solution.
ai_message: AIMessage = state[constants.AGENT_STATE__KEY_MESSAGES][-1]
# ai_message is a list of dictionaries.
json_dict = ai_message.content[0]
Expand All @@ -268,7 +250,9 @@ def evaluate(self, state: AgentState) -> dict:
return {
constants.AGENT_STATE__KEY_MESSAGES: [
# self.format_tool_message(repr(e), ai_message)
FunctionMessage(content=repr(e), name="evaluate")
FunctionMessage(
content=repr(e), name=constants.AGENT_STATE_GRAPH_NODE__EVALUATE
)
]
}
num_test_cases = len(test_cases) if test_cases is not None else 0
Expand Down Expand Up @@ -298,28 +282,33 @@ def evaluate(self, state: AgentState) -> dict:
}

responses = "\n".join(
[f"<test id={i}>\n{r}\n</test>" for i, r in enumerate(test_results)]
[f"<test id='{i}'>\n{r}\n</test>" for i, r in enumerate(test_results)]
)
response = f"Incorrect submission. Please review the failures reported below and respond with updated code.\nPass rate: {pass_rate}\nResults:\n{responses}"

return {
constants.AGENT_STATE__KEY_MESSAGES: [
# self.format_tool_message(response, ai_message)
FunctionMessage(content=response, name="evaluate")
FunctionMessage(
content=response, name=constants.AGENT_STATE_GRAPH_NODE__EVALUATE
)
]
}

def build_agent_graph(self):
builder = StateGraph(AgentState)
# builder.add_node("draft", self.draft_solve)
# builder.add_node("retrieve", self.retrieve_examples)
builder.add_node("solve", self.solve)
builder.add_node("evaluate", self.evaluate)
builder.add_node(constants.AGENT_STATE_GRAPH_NODE__SOLVE, self.solve)
builder.add_node(constants.AGENT_STATE_GRAPH_NODE__EVALUATE, self.evaluate)
# Add connectivity
builder.add_edge(START, "solve")
builder.add_edge(START, constants.AGENT_STATE_GRAPH_NODE__SOLVE)
# builder.add_edge("draft", "retrieve")
# builder.add_edge("retrieve", "solve")
builder.add_edge("solve", "evaluate")
builder.add_edge(
constants.AGENT_STATE_GRAPH_NODE__SOLVE,
constants.AGENT_STATE_GRAPH_NODE__EVALUATE,
)

def control_edge(state: AgentState):
if (
Expand All @@ -329,12 +318,17 @@ def control_edge(state: AgentState):
== constants.AGENT_NODE__EVALUATE_STATUS_NO_TEST_CASES
):
return END
return "solve"
return constants.AGENT_STATE_GRAPH_NODE__SOLVE

builder.add_conditional_edges(
"evaluate", control_edge, {END: END, "solve": "solve"}
constants.AGENT_STATE_GRAPH_NODE__EVALUATE,
control_edge,
{
END: END,
constants.AGENT_STATE_GRAPH_NODE__SOLVE: constants.AGENT_STATE_GRAPH_NODE__SOLVE,
},
)
builder.add_edge("solve", END)
builder.add_edge(constants.AGENT_STATE_GRAPH_NODE__SOLVE, END)
connection = sqlite3.connect(":memory:", check_same_thread=False)
checkpointer = SqliteSaver(conn=connection)
self.agent_graph = builder.compile(checkpointer=checkpointer, debug=False)
31 changes: 18 additions & 13 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
PROJECT_GIT_REPO_URL = "https://github.com/anirbanbasu/chatty-coder"
PROJECT_GIT_REPO_LABEL = "GitHub repository"
PROJECT_NAME = "chatty-coder"
PROJECT_LOGO_PATH = "assets/logo-embed.png"
PROJECT_LOGO_PATH = "assets/logo-embed.svg"

HTTP_TARGET_BLANK = "_blank"

Expand All @@ -46,6 +46,12 @@
EXECUTOR_MESSAGE__WRONG_ANSWER = "Wrong answer"
EXECUTOR_MESSAGE__NO_RESULTS = "No result returned"

AGENT_STATE_GRAPH_NODE__DRAFT_SOLVE = "draft_solve"
AGENT_STATE_GRAPH_NODE__DRAFT_REVIEW = "draft_review"
AGENT_STATE_GRAPH_NODE__SOLVE = "solve"
AGENT_STATE_GRAPH_NODE__EVALUATE = "evaluate"


AGENT_STATE__KEY_CANDIDATE = "candidate"
AGENT_STATE__KEY_EXAMPLES = "examples"
AGENT_STATE__KEY_MESSAGES = "messages"
Expand Down Expand Up @@ -121,9 +127,9 @@
Please respond with a Python 3 solution to the given problem below.
First, output a reasoning through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution.
Then, output a pseudocode in Pascal to implement your concept solution.
Finally, a well-documented working Python 3 code for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
First, output a `reasoning` through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution.
Then, output a `pseudocode` in Pascal to implement your concept solution.
Finally, a well-documented working Python 3 `code` for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
Optional examples of similar problems and solutions (may not be in Python):
{examples}
Expand All @@ -136,9 +142,9 @@
Please respond with a Python 3 solution to the given problem below.
First, output a reasoning through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution.
Then, output a pseudocode in Pascal to implement your concept solution.
Finally, a well-documented working Python 3 code for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
First, output a `reasoning` through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution.
Then, output a `pseudocode` in Pascal to implement your concept solution.
Finally, a well-documented working Python 3 `code` for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
Optional examples of similar problems and solutions (may not be in Python):
{examples}
Expand All @@ -158,9 +164,8 @@
"""

JS__DARK_MODE_TOGGLE = """
() => {
document.body.classList.toggle('dark');
// document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)';
document.querySelector('gradio-app').style.background = 'var(--body-background-fill)';
}
"""
() => {
document.body.classList.toggle('dark');
document.querySelector('gradio-app').style.background = 'var(--body-background-fill)';
}
"""
32 changes: 17 additions & 15 deletions src/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

from dotenv import load_dotenv
import uuid
import os

from coder_agent import MultiAgentOrchestrator, TestCase
from coder_agent import MultiAgentDirectedGraph, TestCase
from utils import parse_env

try:
Expand Down Expand Up @@ -226,7 +225,7 @@ def find_solution(
),
],
)
coder_agent = MultiAgentOrchestrator(llm=self._llm, prompt=prompt)
coder_agent = MultiAgentDirectedGraph(llm=self._llm, solver_prompt=prompt)
coder_agent.build_agent_graph()
config = {"configurable": {"thread_id": uuid.uuid4().hex, "k": 3}}
graph_input = {
Expand All @@ -240,10 +239,10 @@ def find_solution(
config=config,
)
for result in result_iterator:
if "solve" in result:
coder_output: AIMessage = result["solve"][
constants.AGENT_STATE__KEY_MESSAGES
][-1]
if constants.AGENT_STATE_GRAPH_NODE__SOLVE in result:
coder_output: AIMessage = result[
constants.AGENT_STATE_GRAPH_NODE__SOLVE
][constants.AGENT_STATE__KEY_MESSAGES][-1]
if coder_output:
json_dict = coder_output.content[0]
yield [
Expand Down Expand Up @@ -295,7 +294,6 @@ def delete_test_case(

def construct_interface(self):
"""Construct the Gradio user interface and make it available through the `interface` property of this class."""
gr.set_static_paths(paths=[constants.PROJECT_LOGO_PATH])
with gr.Blocks(
# See theming guide at https://www.gradio.app/guides/theming-guide
# theme="gstaff/xkcd",
Expand All @@ -306,19 +304,22 @@ def construct_interface(self):
fill_height=True,
css=constants.CSS__GRADIO_APP,
analytics_enabled=False,
# Delete the cache content every day that is older than a day
delete_cache=(86400, 86400),
) as self.interface:
gr.set_static_paths(paths=[constants.PROJECT_LOGO_PATH])
with gr.Row(elem_id="ui_header", equal_height=True):
with gr.Column(scale=10):
gr.HTML(
"""
f"""
<img
width="384"
height="96"
style="filter: invert(0.5);"
alt="chatty-coder logo"
src="https://raw.githubusercontent.com/anirbanbasu/chatty-coder/master/assets/logo-embed.svg" />
src="/file={constants.PROJECT_LOGO_PATH}" />
""",
# "/file=assets/logo-embed.svg"
# "/file=assets/logo-embed.svg" or "file/assets/logo-embed.svg"?
# "https://raw.githubusercontent.com/anirbanbasu/chatty-coder/master/assets/logo-embed.svg"
)
with gr.Column(scale=3):
Expand Down Expand Up @@ -455,16 +456,17 @@ def construct_interface(self):
def run(self):
"""Run the Gradio app by launching a server."""
self.construct_interface()
allowed_paths = [
f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}/{constants.PROJECT_LOGO_PATH}"
allowed_static_file_paths = [
constants.PROJECT_LOGO_PATH,
]
ic(allowed_paths)
ic(allowed_static_file_paths)
self.interface.queue().launch(
server_name=self._gradio_host,
server_port=self._gradio_port,
show_api=True,
show_error=True,
allowed_paths=allowed_paths,
allowed_paths=allowed_static_file_paths,
enable_monitoring=True,
)


Expand Down

0 comments on commit 1c3dc31

Please sign in to comment.