diff --git a/poetry.lock b/poetry.lock index 327bfd4..a783fc6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2674,6 +2674,21 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdownify" +version = "0.11.6" +description = "Convert HTML to markdown." +optional = false +python-versions = "*" +files = [ + {file = "markdownify-0.11.6-py3-none-any.whl", hash = "sha256:ba35fe289d5e9073bcd7d2cad629278fe25f1a93741fcdc0bfb4f009076d8324"}, + {file = "markdownify-0.11.6.tar.gz", hash = "sha256:009b240e0c9f4c8eaf1d085625dcd4011e12f0f8cec55dedf9ea6f7655e49bfe"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" + [[package]] name = "markupsafe" version = "2.1.5" @@ -4969,6 +4984,21 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tavily-python" +version = "0.3.1" +description = "Python wrapper for the Tavily API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "tavily-python-0.3.1.tar.gz", hash = "sha256:0c1bbc89560f145d7aa1355555ed850f909be2745f7556ea47759bef5ccf8b7c"}, + {file = "tavily_python-0.3.1-py3-none-any.whl", hash = "sha256:34616e6b3a86a3a043b6387d059e3652a7b9a8ac82ade865c047343391dee658"}, +] + +[package.dependencies] +requests = "*" +tiktoken = "0.5.2" + [[package]] name = "tenacity" version = "8.2.3" @@ -5722,4 +5752,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "~3.10.0" -content-hash = "eff920cc3ab56424cc12c209b31f6bf93102770186b5e7c4c5c3aad467c0225e" +content-hash = "cb59b50824d38e37378e1575d1c49a895408abbd5727a7d778dc12ae1a9e75c2" diff --git a/prediction_market_agent/agents/autogen_agent.py b/prediction_market_agent/agents/autogen_agent.py index 5d61d38..f734053 100644 --- a/prediction_market_agent/agents/autogen_agent.py +++ b/prediction_market_agent/agents/autogen_agent.py @@ -7,8 +7,8 @@ from prediction_market_agent import utils from prediction_market_agent.agents.abstract import AbstractAgent -from prediction_market_agent.tools.google_search import GoogleSearchTool -from prediction_market_agent.tools.web_scrape import WebScrapingTool +from prediction_market_agent.tools.web_scrape.basic_summary import WebScrapingTool +from prediction_market_agent.tools.web_search.google import GoogleSearchTool class AutoGenAgent(AbstractAgent): diff --git a/prediction_market_agent/agents/custom_agent.py b/prediction_market_agent/agents/custom_agent.py index 69dea90..91bc03e 100644 --- a/prediction_market_agent/agents/custom_agent.py +++ b/prediction_market_agent/agents/custom_agent.py @@ -16,11 +16,11 @@ ) from prediction_market_agent.ai_models.llama_ai_models import ChatReplicateLLamaModel from prediction_market_agent.ai_models.openai_ai_models import ChatOpenAIModel -from prediction_market_agent.tools.google_search import google_search from prediction_market_agent.tools.tool_exception_handler import tool_exception_handler -from prediction_market_agent.tools.web_scrape_structured import ( +from prediction_market_agent.tools.web_scrape.structured_summary import ( web_scrape_structured_and_summarized, ) +from prediction_market_agent.tools.web_search.google import google_search class CustomAgent(AbstractAgent): diff --git a/prediction_market_agent/agents/known_outcome_agent/benchmark.py b/prediction_market_agent/agents/known_outcome_agent/benchmark.py new file mode 100644 index 0000000..3849d71 --- /dev/null +++ b/prediction_market_agent/agents/known_outcome_agent/benchmark.py @@ -0,0 +1,163 @@ +import time +import typing as t +from datetime import timedelta + +from dotenv import load_dotenv +from prediction_market_agent_tooling.benchmark.agents import AbstractBenchmarkedAgent +from prediction_market_agent_tooling.benchmark.benchmark import Benchmarker +from prediction_market_agent_tooling.benchmark.utils import ( + Market, + MarketSource, + OutcomePrediction, + Prediction, +) +from prediction_market_agent_tooling.tools.utils import utcnow +from pydantic import BaseModel + +from prediction_market_agent.agents.known_outcome_agent.known_outcome_agent import ( + Result, + get_known_outcome, +) + + +class QuestionWithKnownOutcome(BaseModel): + url: t.Optional[str] = None + question: str + result: Result + notes: t.Optional[str] = None + + def to_market(self) -> Market: + dt = utcnow() + return Market( + url=self.url if self.url else "", + question=self.question, + source=MarketSource.MANIFOLD, + p_yes=self.result.to_p_yes() if self.result != Result.UNKNOWN else 0.5, + volume=0.0, + created_time=dt, + close_time=dt, + ) + + +class KnownOutcomeAgent(AbstractBenchmarkedAgent): + def __init__( + self, + agent_name: str, + max_workers: int, + model: str, + max_tries: int, + ) -> None: + self.model = model + self.max_tries = max_tries + super().__init__(agent_name=agent_name, max_workers=max_workers) + + def predict(self, market_question: str) -> Prediction: + answer = get_known_outcome( + model=self.model, + question=market_question, + max_tries=self.max_tries, + ) + if answer.result == Result.UNKNOWN: + return Prediction( + is_predictable=False, + outcome_prediction=None, + ) + else: + return Prediction( + is_predictable=True, + outcome_prediction=OutcomePrediction( + p_yes=answer.result.to_p_yes(), + confidence=1.0, + info_utility=None, + ), + ) + + +if __name__ == "__main__": + load_dotenv() + tomorrow_str = (utcnow() + timedelta(days=1)).strftime("%d %B %Y") + + # Fetch questions from existing markets, or make some up, where the + # outcome is known. + qs_with_known_outcome: list[QuestionWithKnownOutcome] = [ + QuestionWithKnownOutcome( + question=f"Will 'Barbie' win an Academy Award for best original song by {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0xceb2a4ecc217cab440acf60737a9fcfd6d3fbf4b", + result=Result.YES, + notes="Happened on 10th March 2024.", + ), + QuestionWithKnownOutcome( + question=f"Will the 2024 Oscars winner for Best Picture be announced by {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0xb88e4507709148e096bcdfb861b17db7b4d54e6b", + result=Result.YES, + notes="Happened on 10th March 2024.", + ), + QuestionWithKnownOutcome( + question=f"Will Liverpool win against Atalanta in the Europa League quarter-finals by {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0x1d5a462c801360b4bebbda2b9656e52801a27a3b", + result=Result.NO, + notes="The match is scheduled for 11 April 2024.", + ), + QuestionWithKnownOutcome( + question=f"Will Donald Trump officially become the GOP nominee for the 2024 presidential elections by {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0x859a6b465ee1e4a73aab0f2da4428c6255da466c", + result=Result.YES, + notes="Happened on 10th March 2024.", + ), + QuestionWithKnownOutcome( + question=f"Will SpaceX successfully test a Starship reentry without losing contact by {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0xcc9123af8db309e0c60c63f9e2b8b82fc86f458b", + result=Result.NO, + notes="The only scheduled test flight occured, and contact was lost during the test.", + ), + QuestionWithKnownOutcome( + question=f"Will Arsenal reach the Champions League semi-finals on {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0x606efd175b245cd60282a98cef402d4f5e950f92", + result=Result.NO, + notes="They are scheduled to play the first leg of the quarter-finals on 9 April 2024.", + ), + QuestionWithKnownOutcome( + question=f"Will the jury deliver a verdict on James Crumbley's 'bad parenting' case on {tomorrow_str}?", + url="https://aiomen.eth.limo/#/0xe55171beda0d60fd45092ff8bf93d5cb566a2510", + result=Result.NO, + notes="The verdict was announced on 15th March 2024.", + ), + QuestionWithKnownOutcome( + question="Will Lewis Hamilton win the 2024/2025 F1 drivers champtionship?", + result=Result.UNKNOWN, + notes="Outcome is uncertain.", + ), + QuestionWithKnownOutcome( + question="Will the cost of grain in the Spain increase by 20% by 19 July 2024?", + result=Result.UNKNOWN, + notes="Outcome is uncertain.", + ), + QuestionWithKnownOutcome( + question="Will over 360 pople have died while climbing Mount Everest by 1st Jan 2028?", + result=Result.UNKNOWN, + notes="Outcome is uncertain.", + ), + ] + + benchmarker = Benchmarker( + markets=[q.to_market() for q in qs_with_known_outcome], + agents=[ + KnownOutcomeAgent( + agent_name="known_outcome", + model="gpt-4-1106-preview", + max_tries=3, + max_workers=1, + ), + ], + ) + benchmarker.run_agents() + md = benchmarker.generate_markdown_report() + + output = f"./known_outcome_agent_benchmark_report.{int(time.time())}.md" + with open(output, "w") as f: + print(f"Writing benchmark report to: {output}") + f.write(md) + + # Check all predictions are correct, i.e. mean-squared-error == 0 + metrics = benchmarker.compute_metrics() + assert metrics["MSE for `p_yes`"][0] == 0.0 diff --git a/prediction_market_agent/agents/known_outcome_agent/deploy.py b/prediction_market_agent/agents/known_outcome_agent/deploy.py new file mode 100644 index 0000000..bb7ff33 --- /dev/null +++ b/prediction_market_agent/agents/known_outcome_agent/deploy.py @@ -0,0 +1,83 @@ +import getpass +from decimal import Decimal + +from prediction_market_agent_tooling.config import APIKeys +from prediction_market_agent_tooling.deploy.agent import DeployableAgent +from prediction_market_agent_tooling.deploy.constants import OWNER_KEY +from prediction_market_agent_tooling.gtypes import SecretStr, private_key_type +from prediction_market_agent_tooling.markets.agent_market import AgentMarket +from prediction_market_agent_tooling.markets.data_models import BetAmount, Currency +from prediction_market_agent_tooling.markets.markets import MarketType +from prediction_market_agent_tooling.tools.utils import ( + get_current_git_commit_sha, + get_current_git_url, +) +from prediction_market_agent_tooling.tools.web3_utils import verify_address + +from prediction_market_agent.agents.known_outcome_agent.known_outcome_agent import ( + Result, + get_known_outcome, +) + + +def market_is_saturated(market: AgentMarket) -> bool: + return market.p_yes > 0.95 or market.p_no > 0.95 + + +class DeployableKnownOutcomeAgent(DeployableAgent): + model = "gpt-4-1106-preview" + + def load(self) -> None: + self.markets_with_known_outcomes: dict[str, Result] = {} + + def pick_markets(self, markets: list[AgentMarket]) -> list[AgentMarket]: + picked_markets: list[AgentMarket] = [] + for market in markets: + # Assume very high probability markets are already known, and have + # been correctly bet on, and therefore the value of betting on them + # is low. + if not market_is_saturated(market=market): + answer = get_known_outcome( + model=self.model, + question=market.question, + max_tries=3, + ) + if answer.has_known_outcome(): + picked_markets.append(market) + self.markets_with_known_outcomes[market.id] = answer.result + + return picked_markets + + def answer_binary_market(self, market: AgentMarket) -> bool: + # The answer has already been determined in `pick_markets` so we just + # return it here. + return self.markets_with_known_outcomes[market.id].to_boolean() + + def calculate_bet_amount(self, answer: bool, market: AgentMarket) -> BetAmount: + if market.currency == Currency.xDai: + return BetAmount(amount=Decimal(0.1), currency=Currency.xDai) + else: + raise NotImplementedError("This agent only supports xDai markets") + + +if __name__ == "__main__": + agent = DeployableKnownOutcomeAgent() + agent.deploy_gcp( + repository=f"git+{get_current_git_url()}.git@{get_current_git_commit_sha()}", + market_type=MarketType.OMEN, + labels={OWNER_KEY: getpass.getuser()}, + secrets={ + "TAVILY_API_KEY": "GNOSIS_AI_TAVILY_API_KEY:latest", + }, + memory=1024, + api_keys=APIKeys( + BET_FROM_ADDRESS=verify_address( + "0xb611A9f02B318339049264c7a66ac3401281cc3c" + ), + BET_FROM_PRIVATE_KEY=private_key_type("EVAN_OMEN_BETTER_0_PKEY:latest"), + OPENAI_API_KEY=SecretStr("EVAN_OPENAI_API_KEY:latest"), + MANIFOLD_API_KEY=None, + ), + cron_schedule="0 */4 * * *", + timeout=540, + ) diff --git a/prediction_market_agent/agents/known_outcome_agent/known_outcome_agent.py b/prediction_market_agent/agents/known_outcome_agent/known_outcome_agent.py new file mode 100644 index 0000000..983d5bf --- /dev/null +++ b/prediction_market_agent/agents/known_outcome_agent/known_outcome_agent.py @@ -0,0 +1,185 @@ +import json +import typing as t +from datetime import datetime +from enum import Enum + +from langchain.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI +from pydantic import BaseModel + +from prediction_market_agent.tools.web_scrape.basic_summary import _summary +from prediction_market_agent.tools.web_scrape.markdown import web_scrape +from prediction_market_agent.tools.web_search.tavily import web_search + + +class Result(str, Enum): + YES = "YES" + NO = "NO" + UNKNOWN = "UNKNOWN" + + def to_p_yes(self) -> float: + if self is Result.YES: + return 1.0 + elif self is Result.NO: + return 0.0 + else: + raise ValueError("Unexpected result") + + def to_boolean(self) -> bool: + if self is Result.YES: + return True + elif self is Result.NO: + return False + else: + raise ValueError("Unexpected result") + + +class Answer(BaseModel): + result: Result + reasoning: str + + def has_known_outcome(self) -> bool: + return self.result is not Result.UNKNOWN + + +GENERATE_SEARCH_QUERY_PROMPT = """ +The current date is {date_str}. You are trying to determine whether the answer +to the following question has a definite answer. Generate a web search query +based on the question to find relevant information. + +"{question}" + +For example, if the question is: + +"Will Arsenal reach the Champions League semi-finals on 19 March 2025?" + +You might generate the following search query: + +"Champions League semi-finals draw 2024" + +Answer with the single prompt only, and nothing else. +""" + + +ANSWER_FROM_WEBSCRAPE_PROMPT = """ +The current date is {date_str}. You are an expert researcher trying to answer a +question based on information scraped from the web. The question is: + +``` +{question} +``` + +The information you have scraped from the web is: + +``` +{scraped_content} +``` + +You goal is to determine whether the answer can be inferred with a reasonable +degree of certainty from the scraped web content. Answer in json format with the +following fields: +- "result": "", +- "reasoning": "", + +where is a free text field containing your reasoning, and +is a multiple-choice field containing only one of 'YES' (if the answer to the +question is yes), 'NO' (if the answer to the question is no), or 'UNKNOWN' if +you are unable to answer the question with a reasonable degree of certainty from +the web-scraped information. Your answer should only contain this json string, +and nothing else. + +If the question is of the format: "Will X happen by Y?" then the result should +be as follows: +- If X has already happened, the result is 'YES'. +- If not-X has already happened, the result is 'NO'. +- If X has been announced to happen after Y, result 'NO'. +- Otherwise, the result is 'UNKNOWN'. + +If the question is of the format: "Will X happen on Y?" +- If something has happened that necessarily prevents X from happening on Y, the result is 'NO'. +- Otherwise, the result is 'UNKNOWN'. +""" + + +def completion_str_to_json(completion: str) -> dict[str, t.Any]: + """ + Cleans completion JSON in form of a string: + + ```json + { + ... + } + ``` + + into just { ... } + ``` + """ + start_index = completion.find("{") + end_index = completion.rfind("}") + completion = completion[start_index : end_index + 1] + completion_dict: dict[str, t.Any] = json.loads(completion) + return completion_dict + + +def summarize_if_required(content: str, model: str, question: str) -> str: + """ + If the content is too long to fit in the model's context, summarize it. + """ + if model == "gpt-3.5-turbo-0125": # 16k context length + max_length = 10000 + elif model == "gpt-4-1106-preview": # 128k context length + max_length = 100000 + else: + raise ValueError(f"Unknown model: {model}") + + if len(content) > max_length: + return _summary(content=content, objective=question, separators=[" "]) + else: + return content + + +def get_known_outcome(model: str, question: str, max_tries: int) -> Answer: + """ + In a loop, perform web search and scrape to find if the answer to the + question is known. Break if the answer is found, or after a certain number + of tries, and no definite answer is found, return an 'unknown' answer. + """ + tries = 0 + date_str = datetime.now().strftime("%d %B %Y") + previous_urls = [] + llm = ChatOpenAI(model=model, temperature=0.4) + while tries < max_tries: + search_prompt = ChatPromptTemplate.from_template( + template=GENERATE_SEARCH_QUERY_PROMPT + ).format_messages(date_str=date_str, question=question) + search_query = str(llm.invoke(search_prompt).content).strip('"') + search_results = web_search(query=search_query, max_results=5) + if not search_results: + raise ValueError("No search results found.") + + for result in search_results: + if result.url in previous_urls: + continue + previous_urls.append(result.url) + + scraped_content = web_scrape(url=result.url) + scraped_content = summarize_if_required( + content=scraped_content, model=model, question=question + ) + + prompt = ChatPromptTemplate.from_template( + template=ANSWER_FROM_WEBSCRAPE_PROMPT + ).format_messages( + date_str=date_str, + question=question, + scraped_content=scraped_content, + ) + answer = str(llm.invoke(prompt).content) + parsed_answer = Answer.model_validate(completion_str_to_json(answer)) + + if parsed_answer.result is not Result.UNKNOWN: + return parsed_answer + + tries += 1 + + return Answer(result=Result.UNKNOWN, reasoning="Max tries exceeded.") diff --git a/prediction_market_agent/agents/llamaindex_agent.py b/prediction_market_agent/agents/llamaindex_agent.py index 6cce632..2f33379 100644 --- a/prediction_market_agent/agents/llamaindex_agent.py +++ b/prediction_market_agent/agents/llamaindex_agent.py @@ -5,8 +5,8 @@ from prediction_market_agent import utils from prediction_market_agent.agents.abstract import AbstractAgent -from prediction_market_agent.tools.google_search import google_search -from prediction_market_agent.tools.web_scrape import web_scrape +from prediction_market_agent.tools.web_scrape.basic_summary import web_scrape +from prediction_market_agent.tools.web_search.google import google_search class LlamaIndexAgent(AbstractAgent): diff --git a/prediction_market_agent/tools/web_scrape/__init__.py b/prediction_market_agent/tools/web_scrape/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/prediction_market_agent/tools/web_scrape.py b/prediction_market_agent/tools/web_scrape/basic_summary.py similarity index 89% rename from prediction_market_agent/tools/web_scrape.py rename to prediction_market_agent/tools/web_scrape/basic_summary.py index 621d527..a7d932f 100644 --- a/prediction_market_agent/tools/web_scrape.py +++ b/prediction_market_agent/tools/web_scrape/basic_summary.py @@ -6,10 +6,12 @@ from langchain_openai import ChatOpenAI -def _summary(objective: str, content: str) -> str: - llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k-0613") +def _summary( + objective: str, content: str, separators: list[str] = ["\n\n", "\n"] +) -> str: + llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125") text_splitter = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n"], chunk_size=10000, chunk_overlap=500 + separators=separators, chunk_size=10000, chunk_overlap=500 ) docs = text_splitter.create_documents([content]) map_prompt = ( diff --git a/prediction_market_agent/tools/web_scrape/markdown.py b/prediction_market_agent/tools/web_scrape/markdown.py new file mode 100644 index 0000000..6a7dbd7 --- /dev/null +++ b/prediction_market_agent/tools/web_scrape/markdown.py @@ -0,0 +1,53 @@ +import requests +import tenacity +from bs4 import BeautifulSoup +from markdownify import markdownify +from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache +from requests import Response + + +@tenacity.retry( + stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True +) +@persistent_inmemory_cache +def fetch_html(url: str, timeout: int) -> Response: + headers = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:107.0) Gecko/20100101 Firefox/107.0" + } + response = requests.get(url, headers=headers, timeout=timeout) + return response + + +def web_scrape(url: str, timeout: int = 10) -> str: + """ + Taken from polywrap/predictionprophet + + https://github.com/polywrap/predictionprophet/blob/97aeea8f87e9b42da242d00d93ed5754bd64f21e/prediction_prophet/functions/web_scrape.py + """ + try: + response = fetch_html(url=url, timeout=timeout) + + if "text/html" in response.headers.get("Content-Type", ""): + soup = BeautifulSoup(response.content, "html.parser") + + [x.extract() for x in soup.findAll("script")] + [x.extract() for x in soup.findAll("style")] + [x.extract() for x in soup.findAll("noscript")] + [x.extract() for x in soup.findAll("link")] + [x.extract() for x in soup.findAll("head")] + [x.extract() for x in soup.findAll("image")] + [x.extract() for x in soup.findAll("img")] + + text: str = soup.get_text() + text = markdownify(text) + text = " ".join([x.strip() for x in text.split("\n")]) + text = " ".join([x.strip() for x in text.split(" ")]) + + return text + else: + print("Non-HTML content received") + return "" + + except requests.RequestException as e: + print(f"HTTP request failed: {e}") + return "" diff --git a/prediction_market_agent/tools/web_scrape_structured.py b/prediction_market_agent/tools/web_scrape/structured_summary.py similarity index 97% rename from prediction_market_agent/tools/web_scrape_structured.py rename to prediction_market_agent/tools/web_scrape/structured_summary.py index 4145c99..f233d7c 100644 --- a/prediction_market_agent/tools/web_scrape_structured.py +++ b/prediction_market_agent/tools/web_scrape/structured_summary.py @@ -1,7 +1,7 @@ import requests from bs4 import BeautifulSoup, Comment, Tag -from prediction_market_agent.tools.web_scrape import _summary +from prediction_market_agent.tools.web_scrape.basic_summary import _summary def web_scrape_structured_and_summarized( diff --git a/prediction_market_agent/tools/web_search/__init__.py b/prediction_market_agent/tools/web_search/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/prediction_market_agent/tools/google_search.py b/prediction_market_agent/tools/web_search/google.py similarity index 100% rename from prediction_market_agent/tools/google_search.py rename to prediction_market_agent/tools/web_search/google.py diff --git a/prediction_market_agent/tools/web_search/tavily.py b/prediction_market_agent/tools/web_search/tavily.py new file mode 100644 index 0000000..5485f27 --- /dev/null +++ b/prediction_market_agent/tools/web_search/tavily.py @@ -0,0 +1,41 @@ +import tenacity +from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache +from prediction_market_agent_tooling.tools.utils import ( + check_not_none, + secret_str_from_env, +) +from pydantic import BaseModel +from tavily import TavilyClient + + +class WebSearchResult(BaseModel): + url: str + query: str + + +@tenacity.retry( + stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True +) +@persistent_inmemory_cache +def web_search(query: str, max_results: int) -> list[WebSearchResult]: + """ + Web search using Tavily API. + """ + tavily_api_key = check_not_none(secret_str_from_env("TAVILY_API_KEY")) + tavily = TavilyClient(api_key=tavily_api_key.get_secret_value()) + response = tavily.search( + query=query, + search_depth="advanced", + max_results=max_results, + include_raw_content=True, + ) + + results = [ + WebSearchResult( + url=result["url"], + query=query, + ) + for result in response["results"] + ] + + return results diff --git a/pyproject.toml b/pyproject.toml index cd5ea82..4b758f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ prediction-market-agent-tooling = "^0.5.0" pydantic-settings = "^2.1.0" autoflake = "^2.2.1" isort = "^5.13.2" +markdownify = "^0.11.6" +tavily-python = "^0.3.1" [build-system] requires = ["poetry-core"]