Merge pull request #546 from baskaryan/bagatur/generic_llm
RFC: generic llm provider
assafelovic authored Jun 9, 2024
2 parents c29db81 + e627d6f commit 424f049
Showing 7 changed files with 175 additions and 19 deletions.
6 changes: 4 additions & 2 deletions gpt_researcher/config/config.py
@@ -31,9 +31,11 @@ def __init__(self, config_file: str = None):
        self.scraper = os.getenv("SCRAPER", "bs")
        self.max_subtopics = os.getenv("MAX_SUBTOPICS", 3)
        self.doc_path = os.getenv("DOC_PATH", "")

        self.load_config_file()

        if not hasattr(self, "llm_kwargs"):
            self.llm_kwargs = {}

        if self.doc_path:
            self.validate_doc_path()

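The new llm_kwargs attribute defaults to an empty dict when the loaded config file does not supply one. A minimal sketch of how it might be populated (assuming the class is named Config and that load_config_file maps config-file keys onto attributes; the file name and kwargs shown are illustrative):

from gpt_researcher.config.config import Config

# Hypothetical config.json containing: {"llm_kwargs": {"max_retries": 2}}
cfg = Config("config.json")
print(cfg.llm_kwargs)  # {"max_retries": 2} if set in the file, otherwise {}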
4 changes: 3 additions & 1 deletion gpt_researcher/llm_provider/__init__.py
@@ -7,6 +7,7 @@
from .anthropic.anthropic import AnthropicProvider
from .mistral.mistral import MistralProvider
from .huggingface.huggingface import HugginFaceProvider
from .generic import GenericLLMProvider

__all__ = [
"GoogleProvider",
Expand All @@ -17,5 +18,6 @@
"TogetherProvider",
"AnthropicProvider",
"MistralProvider",
"HugginFaceProvider"
"HugginFaceProvider",
"GenericLLMProvider",
]
3 changes: 3 additions & 0 deletions gpt_researcher/llm_provider/generic/__init__.py
@@ -0,0 +1,3 @@
from .base import GenericLLMProvider

__all__ = ["GenericLLMProvider"]
147 changes: 147 additions & 0 deletions gpt_researcher/llm_provider/generic/base.py
@@ -0,0 +1,147 @@
import importlib.util
from typing import Any

from colorama import Fore, Style


class GenericLLMProvider:

    def __init__(self, llm):
        self.llm = llm

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(**kwargs)
        elif provider == "azure_openai":
            _check_pkg("langchain_openai")
            from langchain_openai import AzureChatOpenAI

            llm = AzureChatOpenAI(**kwargs)
        elif provider == "cohere":
            _check_pkg("langchain_cohere")
            from langchain_cohere import ChatCohere

            llm = ChatCohere(**kwargs)
        elif provider == "google_vertexai":
            _check_pkg("langchain_google_vertexai")
            from langchain_google_vertexai import ChatVertexAI

            llm = ChatVertexAI(**kwargs)
        elif provider == "google_genai":
            _check_pkg("langchain_google_genai")
            from langchain_google_genai import ChatGoogleGenerativeAI

            llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "fireworks":
            _check_pkg("langchain_fireworks")
            from langchain_fireworks import ChatFireworks

            llm = ChatFireworks(**kwargs)
        elif provider == "ollama":
            _check_pkg("langchain_community")
            from langchain_community.chat_models import ChatOllama

            llm = ChatOllama(**kwargs)
        elif provider == "together":
            _check_pkg("langchain_together")
            from langchain_together import ChatTogether

            llm = ChatTogether(**kwargs)
        elif provider == "mistralai":
            _check_pkg("langchain_mistralai")
            from langchain_mistralai import ChatMistralAI

            llm = ChatMistralAI(**kwargs)
        elif provider == "huggingface":
            _check_pkg("langchain_huggingface")
            from langchain_huggingface import ChatHuggingFace

            if "model" in kwargs or "model_name" in kwargs:
                model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatHuggingFace(**kwargs)
        elif provider == "groq":
            _check_pkg("langchain_groq")
            from langchain_groq import ChatGroq

            llm = ChatGroq(**kwargs)
        elif provider == "bedrock":
            _check_pkg("langchain_aws")
            from langchain_aws import ChatBedrock

            if "model" in kwargs or "model_name" in kwargs:
                model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatBedrock(**kwargs)
        else:
            supported = ", ".join(_SUPPORTED_PROVIDERS)
            raise ValueError(
                f"Unsupported {provider=}.\n\nSupported model providers are: "
                f"{supported}"
            )
        return cls(llm)

    async def get_chat_response(self, messages, stream, websocket=None):
        if not stream:
            # Getting output from the model chain using ainvoke for asynchronous invoking
            output = await self.llm.ainvoke(messages)

            return output.content

        else:
            return await self.stream_response(messages, websocket)

    async def stream_response(self, messages, websocket=None):
        paragraph = ""
        response = ""

        # Streaming the response using the chain astream method from langchain
        async for chunk in self.llm.astream(messages):
            content = chunk.content
            if content is not None:
                response += content
                paragraph += content
                if "\n" in paragraph:
                    if websocket is not None:
                        await websocket.send_json({"type": "report", "output": paragraph})
                    else:
                        print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
                    paragraph = ""

        return response


_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
}


def _check_pkg(pkg: str) -> None:
    if not importlib.util.find_spec(pkg):
        pkg_kebab = pkg.replace("_", "-")
        raise ImportError(
            f"Unable to import {pkg_kebab}. Please install with "
            f"`pip install -U {pkg_kebab}`"
        )
5 changes: 5 additions & 0 deletions gpt_researcher/master/actions.py
@@ -70,6 +70,7 @@ async def choose_agent(query, cfg, parent_query=None, cost_callback: callable =
{"role": "user", "content": f"task: {query}"}],
temperature=0,
llm_provider=cfg.llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback
)
agent_dict = json.loads(response)
@@ -103,6 +104,7 @@ async def get_sub_queries(query: str, agent_role_prompt: str, cfg, parent_query:
{"role": "user", "content": generate_search_queries_prompt(query, parent_query, report_type, max_iterations=max_research_iterations)}],
temperature=0,
llm_provider=cfg.llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback
)

@@ -206,6 +208,7 @@ async def summarize_url(query, raw_data, agent_role_prompt, cfg, cost_callback:
{"role": "user", "content": f"{generate_summary_prompt(query, raw_data)}"}],
temperature=0,
llm_provider=cfg.llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback
)
except Exception as e:
@@ -261,6 +264,7 @@ async def generate_report(
            stream=True,
            websocket=websocket,
            max_tokens=cfg.smart_token_limit,
            llm_kwargs=cfg.llm_kwargs,
            cost_callback=cost_callback
        )
    except Exception as e:
@@ -298,6 +302,7 @@ async def get_report_introduction(query, context, role, config, websocket=None,
        stream=True,
        websocket=websocket,
        max_tokens=config.smart_token_limit,
        llm_kwargs=config.llm_kwargs,
        cost_callback=cost_callback
    )

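Each call site now forwards cfg.llm_kwargs, so provider-specific options set once on the config reach every completion. A hedged sketch (the base_url value is hypothetical and assumes an OpenAI-compatible endpoint):

cfg.llm_kwargs = {"base_url": "http://localhost:1234/v1"}  # hypothetical endpoint
response = await create_chat_completion(
    model=cfg.smart_llm_model,
    messages=[{"role": "user", "content": "task: example"}],
    temperature=0,
    llm_provider=cfg.llm_provider,
    llm_kwargs=cfg.llm_kwargs,
)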
25 changes: 9 additions & 16 deletions gpt_researcher/utils/llm.py
@@ -3,7 +3,7 @@

import json
import logging
-from typing import Optional
+from typing import Optional, Any, Dict

from colorama import Fore, Style
from fastapi import WebSocket
@@ -15,7 +15,7 @@
from .validators import Subtopics


-def get_provider(llm_provider):
+def get_llm(llm_provider, **kwargs):
    match llm_provider:
        case "openai":
            from ..llm_provider import OpenAIProvider
@@ -44,12 +44,12 @@ def get_provider(llm_provider):
case "anthropic":
from ..llm_provider import AnthropicProvider
llm_provider = AnthropicProvider
# Generic case for all other providers supported by Langchain
case _:
raise Exception("LLM provider not found. "
"Check here to learn more about support LLMs: "
"https://docs.gptr.dev/docs/gpt-researcher/llms")
from gpt_researcher.llm_provider import GenericLLMProvider
return GenericLLMProvider.from_provider(llm_provider, **kwargs)

return llm_provider
return llm_provider(**kwargs)


async def create_chat_completion(
@@ -60,6 +60,7 @@ async def create_chat_completion(
    llm_provider: Optional[str] = None,
    stream: Optional[bool] = False,
    websocket: WebSocket | None = None,
    llm_kwargs: Dict[str, Any] | None = None,
    cost_callback: callable = None
) -> str:
"""Create a chat completion using the OpenAI API
@@ -84,12 +85,7 @@
f"Max tokens cannot be more than 8001, but got {max_tokens}")

    # Get the provider from supported providers
-    ProviderClass = get_provider(llm_provider)
-    provider = ProviderClass(
-        model,
-        temperature,
-        max_tokens
-    )
+    provider = get_llm(llm_provider, model=model, temperature=temperature, max_tokens=max_tokens, **(llm_kwargs or {}))

    response = ""
    # create response
@@ -123,10 +119,7 @@ async def construct_subtopics(task: str, data: str, config, subtopics: list = []

        temperature = config.temperature
        # temperature = 0 # Note: temperature throughout the code base is currently set to Zero
-        ProviderClass = get_provider(config.llm_provider)
-        provider = ProviderClass(model=config.smart_llm_model,
-                                 temperature=temperature,
-                                 max_tokens=config.smart_token_limit)
+        provider = get_llm(config.llm_provider, model=config.smart_llm_model, temperature=temperature, max_tokens=config.smart_token_limit, **config.llm_kwargs)
        model = provider.llm


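Providers without a dedicated class now fall through to GenericLLMProvider instead of raising. A minimal sketch (the Groq model name is illustrative; assumes langchain-groq is installed and GROQ_API_KEY is set):

provider = get_llm("groq", model="mixtral-8x7b-32768", temperature=0, max_tokens=1024)
print(type(provider).__name__)  # GenericLLMProvider, wrapping a ChatGroq instance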
4 changes: 4 additions & 0 deletions requirements.txt
@@ -21,6 +21,10 @@ langchain_anthropic>=0.1,<0.2
langchain_mistralai>=0.1,<0.2
langchain_huggingface>=0.0.1,<0.1
langchain_together>=0.1,<0.2
langchain_cohere
langchain_google_vertexai
langchain_fireworks
langchain_aws
tiktoken
tavily-python
arxiv
