From de8626a9b8a23c7ef8bf32121eb1f0b713a77c07 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Fri, 20 Sep 2024 14:45:56 +0200 Subject: [PATCH] Fix parallelism in Thinker --- poetry.lock | 27 ++++++++++++- .../think_thoroughly_agent.py | 38 +++++++++++-------- pyproject.toml | 1 + scripts/replicator_stats.py | 7 ++-- 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0bd4cc6..407fb7a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1232,6 +1232,17 @@ azure = ["azure-storage-blob (>=12)", "azure-storage-file-datalake (>=12)"] gs = ["google-cloud-storage"] s3 = ["boto3 (>=1.34.0)"] +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + [[package]] name = "cohere" version = "5.9.3" @@ -4489,6 +4500,20 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] +[[package]] +name = "loky" +version = "3.4.1" +description = "A robust implementation of concurrent.futures.ProcessPoolExecutor" +optional = false +python-versions = ">=3.7" +files = [ + {file = "loky-3.4.1-py3-none-any.whl", hash = "sha256:7132da80d1a057b5917ff32c7867b65ed164aae84c259a1dbc44375791280c87"}, + {file = "loky-3.4.1.tar.gz", hash = "sha256:66db350de68c301299c882ace3b8f06ba5c4cb2c45f8fcffd498160ce8280753"}, +] + +[package.dependencies] +cloudpickle = "*" + [[package]] name = "lru-dict" version = "1.2.0" @@ -10910,4 +10935,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "~3.10.0" -content-hash = "bf1eb8858ba25625332d035dba93c1fbedb42b696ef7f96c796d0c71b8cb433a" +content-hash = "9a1896ba2d3d8efb4e8592a3d6fb5936202893902ae693675a7c2d1984424689" diff --git a/prediction_market_agent/agents/think_thoroughly_agent/think_thoroughly_agent.py b/prediction_market_agent/agents/think_thoroughly_agent/think_thoroughly_agent.py index e56a490..f68db47 100644 --- a/prediction_market_agent/agents/think_thoroughly_agent/think_thoroughly_agent.py +++ b/prediction_market_agent/agents/think_thoroughly_agent/think_thoroughly_agent.py @@ -5,8 +5,8 @@ import tenacity from crewai import Agent, Crew, Process, Task -from langchain.utilities.tavily_search import TavilySearchAPIWrapper from langchain_community.tools.tavily_search import TavilySearchResults +from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, @@ -14,6 +14,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.pydantic_v1 import SecretStr from langchain_openai import ChatOpenAI +from loky import get_reusable_executor from openai import APIError from prediction_market_agent_tooling.deploy.agent import initialize_langfuse from prediction_market_agent_tooling.loggers import logger @@ -22,11 +23,6 @@ OmenSubgraphHandler, ) from prediction_market_agent_tooling.tools.langfuse_ import langfuse_context, observe -from prediction_market_agent_tooling.tools.parallelism import ( - DEFAULT_PROCESSPOOL_EXECUTOR, - par_generator, - par_map, -) from prediction_market_agent_tooling.tools.tavily_storage.tavily_models import ( TavilyStorage, ) @@ -115,7 +111,14 @@ async def _arun( class ThinkThoroughlyBase(ABC): identifier: str - def __init__(self, model: str, enable_langfuse: bool, memory: bool = True) -> None: + def __init__( + self, + model: str, + enable_langfuse: bool, + memory: bool = True, + max_workers: int = 3, + worker_timeout: int = 5 * 60, + ) -> None: self.model = model self.enable_langfuse = enable_langfuse self.subgraph_handler = OmenSubgraphHandler() @@ -124,6 +127,8 @@ def __init__(self, model: str, enable_langfuse: bool, memory: bool = True) -> No self._long_term_memory = ( LongTermMemoryTableHandler(self.identifier) if self.memory else None ) + self.max_workers = max_workers + self.worker_timeout = worker_timeout disable_crewai_telemetry() # To prevent telemetry from being sent to CrewAI @@ -255,11 +260,13 @@ def get_correlated_markets(self, question: str) -> list[CorrelatedMarketInput]: 5, text=question ) - markets = par_map( - items=[q.market_address for q in nearest_questions], - func=lambda market_address: self.subgraph_handler.get_omen_market_by_market_id( + markets = get_reusable_executor( + max_workers=self.max_workers, timeout=self.worker_timeout + ).map( + lambda market_address: self.subgraph_handler.get_omen_market_by_market_id( market_id=market_address ), + [q.market_address for q in nearest_questions], ) return [CorrelatedMarketInput.from_omen_market(market) for market in markets] @@ -348,8 +355,12 @@ def answer_binary_market( f"Starting to generate predictions for each scenario, iteration {iteration + 1} / {n_iterations}." ) - sub_predictions = par_generator( - items=[ + sub_predictions = get_reusable_executor( + max_workers=self.max_workers, + timeout=self.worker_timeout, + ).map( + process_scenarios, + [ ( self.enable_langfuse, unique_id, @@ -364,8 +375,6 @@ def answer_binary_market( + conditional_scenarios.scenarios ) ], - func=process_scenarios, - executor=DEFAULT_PROCESSPOOL_EXECUTOR, ) scenarios_with_probs = [] @@ -577,7 +586,6 @@ def process_scenarios( ) -> tuple[str, AnswerWithScenario | None]: # Needs to be a normal function outside of class, because `lambda` and `self` aren't pickable for processpool executor, # and process pool executor is required, because ChromaDB isn't thread-safe. - # Input arguments needs to be as a single tuple, because par_generator requires a single argument. ( enable_langfuse, unique_id, diff --git a/pyproject.toml b/pyproject.toml index c802136..6cdaf57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ plotly = "^5.22.0" prediction-prophet = { git = "https://github.com/agentcoinorg/predictionprophet.git", rev = "93e052b37fa87573f5d06741ad86e184836977a0" } transformers = "^4.43.3" openfactverification-kongzii = "^0.2.0" +loky = "^3.4.1" [tool.poetry.group.dev.dependencies] langchain-chroma = "^0.1.2" diff --git a/scripts/replicator_stats.py b/scripts/replicator_stats.py index a19caf4..6005104 100644 --- a/scripts/replicator_stats.py +++ b/scripts/replicator_stats.py @@ -2,11 +2,11 @@ from pprint import pprint import typer +from loky import get_reusable_executor from prediction_market_agent_tooling.markets.omen.omen import OmenAgentMarket from prediction_market_agent_tooling.markets.omen.omen_subgraph_handler import ( OmenSubgraphHandler, ) -from prediction_market_agent_tooling.tools.parallelism import par_generator from tqdm import tqdm from prediction_market_agent.agents.replicate_to_omen_agent.deploy import ( @@ -20,17 +20,18 @@ def main() -> None: limit=None, creator=REPLICATOR_ADDRESS, ) + executor = get_reusable_executor(max_workers=5, timeout=5 * 60) bets_for_market = { market.id: bets for market, bets in tqdm( - par_generator( - markets, + executor.map( lambda m: ( m, OmenSubgraphHandler().get_bets( market_id=m.market_maker_contract_address_checksummed ), ), + markets, ), total=len(markets), desc="Loading bets",