Skip to content

Commit

Permalink
fix: pre-commit lint
Browse files Browse the repository at this point in the history
  • Loading branch information
naaive committed Jul 24, 2024
1 parent 9a00529 commit a8c0cf7
Show file tree
Hide file tree
Showing 19 changed files with 87 additions and 198 deletions.
6 changes: 1 addition & 5 deletions src/openagent/agent/postgres_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ def __init__(
@property
def messages(self) -> List[BaseMessage]: # type: ignore
with DBSession() as db_session:
histories = (
db_session.query(ChatHistory)
.filter(ChatHistory.session_id == self.session_id)
.all()
)
histories = db_session.query(ChatHistory).filter(ChatHistory.session_id == self.session_id).all()
lst = compose_left(
map(compose_left(lambda x: x.message, json.loads)),
filter(lambda x: x["type"] in ["ai", "human"]),
Expand Down
4 changes: 1 addition & 3 deletions src/openagent/agent/session_title.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ async def agen_session_title(user_id: str, session_id: str, history: str) -> lis
output = output.strip("'").strip('"')
logger.info(f"session title generated: {output}")
with DBSession() as db_session:
db_session.query(ChatSession).filter(
ChatSession.session_id == session_id
).update({ChatSession.title: output})
db_session.query(ChatSession).filter(ChatSession.session_id == session_id).update({ChatSession.title: output})
db_session.commit()
return output
8 changes: 2 additions & 6 deletions src/openagent/agent/stream_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
self.current_llm_block_id = str(uuid.uuid4())
Expand Down Expand Up @@ -234,9 +232,7 @@ async def on_llm_error(
**kwargs: Any,
) -> None:
self.done.set()
return await super().on_llm_error(
error, run_id=run_id, parent_run_id=parent_run_id, tags=tags, **kwargs
)
return await super().on_llm_error(error, run_id=run_id, parent_run_id=parent_run_id, tags=tags, **kwargs)

# TODO implement the other methods

Expand Down
8 changes: 3 additions & 5 deletions src/openagent/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os

from chainlit.utils import mount_chainlit
from dotenv import load_dotenv
Expand Down Expand Up @@ -34,18 +35,15 @@ class Input(BaseModel):
text: str


@app.post(
"/api/stream_chat",
description="streaming chat api for openagent"
)
@app.post("/api/stream_chat", description="streaming chat api for openagent")
async def outline_creation(req: Input):
agent = get_agent("openagent")

async def stream():
async for event in agent.astream_events({"input": req.text}, version="v1"):
kind = event["event"]
if kind == "on_chat_model_stream":
yield json.dumps(event['data']['chunk'].dict(), ensure_ascii=False)
yield json.dumps(event["data"]["chunk"].dict(), ensure_ascii=False)

return EventSourceResponse(stream(), media_type="text/event-stream")

Expand Down
23 changes: 6 additions & 17 deletions src/openagent/dto/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ class SessionTreeNodeDTOType(str, Enum):

class SessionTreeNodeDTO(BaseModel):
session_id: str = Field(description="session id")
parent_id: str | None = Field(
example=None, default=None, description="parent id, if null, is root folder"
)
parent_id: str | None = Field(example=None, default=None, description="parent id, if null, is root folder")
title: str | None = Field(default=None, description="session title")
order: int = Field(
description="order in parent folder, session will sort by order desc"
)
order: int = Field(description="order in parent folder, session will sort by order desc")
created_at: datetime = Field(description="create time")
children: list | None = []
type: SessionTreeNodeDTOType = SessionTreeNodeDTOType.folder
Expand Down Expand Up @@ -51,19 +47,15 @@ def build_session_tree_node(node: ChatSession) -> SessionTreeNodeDTO:
class NewSessionFolderDTO(BaseModel):
user_id: str = Field(example="jackma")
title: str = Field(example="folder1")
order: int = Field(
example=1, description="order in parent folder, session will sort by order desc"
)
order: int = Field(example=1, description="order in parent folder, session will sort by order desc")
parent_id: str | None = Field(
example=None,
default=None,
description="parent id, if null, will create root folder",
)

class Config:
json_schema_extra: ClassVar = {
"example": {"user_id": "jackma", "title": "folder1", "order": 0}
}
json_schema_extra: ClassVar = {"example": {"user_id": "jackma", "title": "folder1", "order": 0}}


class SessionTab(str, Enum):
Expand Down Expand Up @@ -100,14 +92,11 @@ class MoveSessionDTO(BaseModel):
user_id: str = Field(description="user id", example="jackma")
from_session_id: str = Field(description="source session id", example="1234567890")
to_session_tab: SessionTab = Field(
description="target tab, favorite or recent. "
"if recent, to_session_id will be ignored",
description="target tab, favorite or recent. " "if recent, to_session_id will be ignored",
example="favorite",
)
to_session_id: str | None = Field(
description="target parent session id, only valid when"
" to_session_tab is favorite, if null, "
"will move to root folder",
description="target parent session id, only valid when" " to_session_tab is favorite, if null, " "will move to root folder",
example="0987654321",
default=None,
)
4 changes: 1 addition & 3 deletions src/openagent/dto/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ class TransferQueryDTO(BaseModel):
example="0x4d2bf3A34a2311dB4b3D20D4719209EDaDBf69b6",
)
token: str = Field(description="token", example="ETH")
logoURI: str | None = Field(
description="logo uri", example="https://li.quest/logo.png"
)
logoURI: str | None = Field(description="logo uri", example="https://li.quest/logo.png")
decimals: int | None = Field(description="decimals", example=18)


Expand Down
21 changes: 8 additions & 13 deletions src/openagent/experts/feed_expert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import Optional, Type

import aiohttp
Expand Down Expand Up @@ -28,24 +27,22 @@ class FeedExpert(BaseTool):
args_schema: Type[ParamSchema] = ParamSchema

def _run(
self,
address: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
self,
address: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError

async def _arun(
self,
address: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
self,
address: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
):
return await fetch_feeds(address)


async def fetch_feeds(address: str):
url = (
f"""{settings.RSS3_DATA_API}/decentralized/{address}?limit=5&action_limit=10"""
)
url = f"""{settings.RSS3_DATA_API}/decentralized/{address}?limit=5&action_limit=10"""
headers = {"Accept": "application/json"}
async with aiohttp.ClientSession() as session:
logger.info(f"fetching {url}")
Expand All @@ -62,9 +59,7 @@ async def fetch_feeds(address: str):
if "actions" in activity:
formatted_activity += "### Actions:\n"
for action in activity["actions"]:
formatted_activity += (
f"- {action['type']} from {action['from']} to {action['to']}\n"
)
formatted_activity += f"- {action['type']} from {action['from']} to {action['to']}\n"
if "metadata" in action:
for key, value in action["metadata"].items():
formatted_activity += f" - {key}: {value}\n"
Expand Down
10 changes: 2 additions & 8 deletions src/openagent/experts/nft_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@


class ARGS(BaseModel):
action: str = Field(
description="Specify the operation to perform: 'search' for NFT "
"collection search, 'rank' for collection ranking"
)
action: str = Field(description="Specify the operation to perform: 'search' for NFT " "collection search, 'rank' for collection ranking")
keyword: str = Field(
default="",
description="NFT symbol or collection name, required only for 'action=search'",
Expand Down Expand Up @@ -53,10 +50,7 @@ def _run(
elif action == "rank":
return self.collection_ranking(sort_field)
else:
return (
"Error: Unknown operation type. "
"Please specify 'action' as 'search' or 'rank'."
)
return "Error: Unknown operation type. " "Please specify 'action' as 'search' or 'rank'."

async def _arun(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/openagent/experts/search_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SearchExpert(BaseTool):
description = """
A versatile search tool that can perform various types of searches based on the query type:
- For queries related to charts, data visualization, or dashboards, use Dune search.
- For queries about project introductions, current events or real-time information, use Google search.""" # noqa: E501
- For queries about project introductions, current events or real-time information, use Google search."""
args_schema: Type[SearchSchema] = SearchSchema

def _run(
Expand Down
61 changes: 24 additions & 37 deletions src/openagent/experts/swap_expert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Optional, Type, Literal
from typing import Literal, Optional, Type

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
Expand All @@ -18,53 +18,43 @@ class ParamSchema(BaseModel):
"""
Schema for the parameters required for a token swap.
"""
from_token: str = Field(
description="Symbol of the token to swap from, e.g., 'BTC', 'ETH', 'RSS3', 'USDT', 'USDC'. Default: 'ETH'."
)
to_token: str = Field(
description="Symbol of the token to swap to, e.g., 'BTC', 'ETH', 'RSS3', 'USDT', 'USDC'. Default: 'ETH'."
)
from_chain: ChainLiteral = Field(
default="ETH",
description="Blockchain network to swap from. Default: 'ETH'."
)
to_chain: ChainLiteral = Field(
default="ETH",
description="Blockchain network to swap to. Default: 'ETH'."
)
amount: str = Field(
description="Amount of the from-side token to swap, e.g., '0.1', '1', '10'. Default: '1'."
)

from_token: str = Field(description="Symbol of the token to swap from, e.g., 'BTC', 'ETH', 'RSS3', 'USDT', 'USDC'. Default: 'ETH'.")
to_token: str = Field(description="Symbol of the token to swap to, e.g., 'BTC', 'ETH', 'RSS3', 'USDT', 'USDC'. Default: 'ETH'.")
from_chain: ChainLiteral = Field(default="ETH", description="Blockchain network to swap from. Default: 'ETH'.")
to_chain: ChainLiteral = Field(default="ETH", description="Blockchain network to swap to. Default: 'ETH'.")
amount: str = Field(description="Amount of the from-side token to swap, e.g., '0.1', '1', '10'. Default: '1'.")


class SwapExpert(BaseTool):
"""
Tool for generating a swap widget for cryptocurrency swaps.
"""

name = "swap"
description = "Use this tool to generate a swap widget for the user to swap cryptocurrencies."
args_schema: Type[ParamSchema] = ParamSchema
return_direct = False

def _run(
self,
from_token: str,
to_token: str,
from_chain: ChainLiteral,
to_chain: ChainLiteral,
amount: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
self,
from_token: str,
to_token: str,
from_chain: ChainLiteral,
to_chain: ChainLiteral,
amount: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError

async def _arun(
self,
from_token: str,
to_token: str,
from_chain: ChainLiteral = "ETH",
to_chain: ChainLiteral = "ETH",
amount: str = "1",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
self,
from_token: str,
to_token: str,
from_chain: ChainLiteral = "ETH",
to_chain: ChainLiteral = "ETH",
amount: str = "1",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
):
return await fetch_swap(from_token, to_token, from_chain, to_chain, amount)

Expand All @@ -87,10 +77,7 @@ async def fetch_swap(from_token: str, to_token: str, from_chain: ChainLiteral, t
to_chain_id = chain_name_to_id(to_chain)

# Fetch token data concurrently
from_token_data, to_token_data = await asyncio.gather(
select_best_token(from_token, from_chain_id),
select_best_token(to_token, to_chain_id)
)
from_token_data, to_token_data = await asyncio.gather(select_best_token(from_token, from_chain_id), select_best_token(to_token, to_chain_id))

swap = Swap(
from_token=get_token_data_by_key(from_token_data, "symbol"),
Expand All @@ -99,6 +86,6 @@ async def fetch_swap(from_token: str, to_token: str, from_chain: ChainLiteral, t
to_token_address=get_token_data_by_key(to_token_data, "address"),
from_chain_name=from_chain,
to_chain_name=to_chain,
amount=amount
amount=amount,
)
return swap.model_dump_json()
9 changes: 3 additions & 6 deletions src/openagent/experts/token_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Dict, List
from typing import Dict, List, Optional

import aiohttp
from aiocache import Cache
Expand Down Expand Up @@ -52,7 +52,7 @@ async def fetch_tokens() -> Dict[str, List[Dict]]:
headers = {"Accept": "application/json"}
logger.info(f"Fetching new data from {url}")

async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession() as session: # noqa
async with session.get(url, headers=headers) as response:
token_list = await response.json()
return token_list["tokens"]
Expand All @@ -75,10 +75,7 @@ async def select_best_token(keyword: str, chain_id: str) -> Optional[Dict]:
tokens_on_chain = tokens.get(chain_id, [])

# Filter based on symbol and name
results = [
token for token in tokens_on_chain
if token["symbol"].lower() == keyword or token["name"].lower() == keyword
]
results = [token for token in tokens_on_chain if token["symbol"].lower() == keyword or token["name"].lower() == keyword]

if results:
if len(results) == 1:
Expand Down
2 changes: 1 addition & 1 deletion src/openagent/experts/transfer_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel, Field

from openagent.dto.mutation import Transfer
from openagent.experts.token_util import chain_name_to_id, select_best_token, get_token_data_by_key
from openagent.experts.token_util import chain_name_to_id, get_token_data_by_key, select_best_token


class ParamSchema(BaseModel):
Expand Down
6 changes: 2 additions & 4 deletions src/openagent/index/feed_scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def fetch_iqwiki_feeds(since_timestamp, until_timestamp, limit=10, cursor=None)
return fetch_feeds("IQ.Wiki", since_timestamp, until_timestamp, limit, cursor)


def fetch_feeds(
platform, since_timestamp, until_timestamp, limit=10, cursor=None, max_retries=3
) -> dict:
def fetch_feeds(platform, since_timestamp, until_timestamp, limit=10, cursor=None, max_retries=3) -> dict:
"""
Fetch feeds from a platform with retry functionality.
"""
Expand Down Expand Up @@ -55,4 +53,4 @@ def _fetch_feeds():

if __name__ == "__main__":
feeds = fetch_feeds("Mirror", 0, 0, 1, None, 3)
print(json.dumps(feeds,ensure_ascii=False))
print(json.dumps(feeds, ensure_ascii=False))
4 changes: 2 additions & 2 deletions src/openagent/router/onboarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def generate_stream():
tokens = re.findall(r"\S+\s*", introduction_text)

for token in tokens:
yield f'{{"message_id":"{unique_message_id}","block_id":null,"type":"natural_language","body":{json.dumps(token)}}}' # noqa: E501
yield f'{{"message_id":"{unique_message_id}","block_id":null,"type":"natural_language","body":{json.dumps(token)}}}'

questions_json = json.dumps(suggested_questions)
yield f'{{"message_id":"{unique_message_id}","block_id":null,"type":"suggested_questions","body":{questions_json}}}' # noqa: E501
yield f'{{"message_id":"{unique_message_id}","block_id":null,"type":"suggested_questions","body":{questions_json}}}'


@onboarding_router.post("/onboarding/", response_model=ChatResp)
Expand Down
Loading

0 comments on commit a8c0cf7

Please sign in to comment.