Skip to content

Commit

Permalink
Refactored providers into their own folder. Also added support Google…
Browse files Browse the repository at this point in the history
… Gemini models
  • Loading branch information
proy9714 committed Mar 23, 2024
1 parent 12df54e commit 9587418
Show file tree
Hide file tree
Showing 12 changed files with 519 additions and 64 deletions.
2 changes: 1 addition & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"selenium_web_browser": "chrome",
"search_api": "tavily",
"embedding_provider": "openai",
"llm_provider": "ChatOpenAI",
"llm_provider": "openai",
"fast_llm_model": "gpt-3.5-turbo-16k",
"smart_llm_model": "gpt-4",
"fast_token_limit": 2000,
Expand Down
4 changes: 2 additions & 2 deletions docs/docs/gpt-researcher/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Here is an example of the default config.py file found in `/gpt_researcher/confi
def __init__(self, config_file: str = None):
self.config_file = config_file
self.retriever = "tavily"
self.llm_provider = "ChatOpenAI"
self.llm_provider = "openai"
self.fast_llm_model = "gpt-3.5-turbo-16k"
self.smart_llm_model = "gpt-4-1106-preview"
self.fast_token_limit = 2000
Expand All @@ -42,7 +42,7 @@ def __init__(self, config_file: str = None):

Please note that you can also include your own external JSON file by adding the path in the `config_file` param.

To learn more about additional LLM support you can check out the [Langchain Adapter](https://python.langchain.com/docs/guides/adapters/openai) and [Langchain supported LLMs](https://python.langchain.com/docs/integrations/llms/) documentation. Simply pass different model names in the `llm_provider` config param.
To learn more about additional LLM support you can check out the [Langchain Adapter](https://python.langchain.com/docs/guides/adapters/openai) and [Langchain supported LLMs](https://python.langchain.com/docs/integrations/llms/) documentation. Simply pass different provider names in the `llm_provider` config param.

You can also change the search engine by modifying the `retriever` param to others such as `duckduckgo`, `googleAPI`, `googleSerp`, `searx` and more.

Expand Down
2 changes: 1 addition & 1 deletion gpt_researcher/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, config_file: str = None):
self.config_file = config_file if config_file else os.getenv('CONFIG_FILE')
self.retriever = os.getenv('SEARCH_RETRIEVER', "tavily")
self.embedding_provider = os.getenv('EMBEDDING_PROVIDER', 'openai')
self.llm_provider = os.getenv('LLM_PROVIDER', "ChatOpenAI")
self.llm_provider = os.getenv('LLM_PROVIDER', "openai")
self.fast_llm_model = os.getenv('FAST_LLM_MODEL', "gpt-3.5-turbo-16k")
self.smart_llm_model = os.getenv('SMART_LLM_MODEL', "gpt-4-1106-preview")
self.fast_token_limit = int(os.getenv('FAST_TOKEN_LIMIT', 2000))
Expand Down
7 changes: 7 additions & 0 deletions gpt_researcher/llm_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .google.google import GoogleProvider
from .openai.openai import OpenAIProvider

__all__ = [
"GoogleProvider",
"OpenAIProvider"
]
Empty file.
103 changes: 103 additions & 0 deletions gpt_researcher/llm_provider/google/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

from colorama import Fore, Style
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI


class GoogleProvider:

def __init__(
self,
model,
temperature,
max_tokens
):
# May be extended to support more google models in the future
self.model = "gemini-pro"
self.temperature = temperature
self.max_tokens = max_tokens
self.api_key = self.get_api_key()
self.llm = self.get_llm_model()

def get_api_key(self):
"""
Gets the GEMINI_API_KEY
Returns:
"""
try:
api_key = os.environ["GEMINI_API_KEY"]
except:
raise Exception(
"GEMINI API key not found. Please set the GEMINI_API_KEY environment variable.")
return api_key

def get_llm_model(self):
# Initializing the chat model
llm = ChatGoogleGenerativeAI(
convert_system_message_to_human=True,
model=self.model,
temperature=self.temperature,
max_output_tokens=self.max_tokens,
google_api_key=self.api_key
)

return llm

def convert_messages(self, messages):
"""
The function `convert_messages` converts messages based on their role into either SystemMessage
or HumanMessage objects.
Args:
messages: It looks like the code snippet you provided is a function called `convert_messages`
that takes a list of messages as input and converts each message based on its role into either a
`SystemMessage` or a `HumanMessage`.
Returns:
The `convert_messages` function is returning a list of converted messages based on the input
`messages`. The function checks the role of each message in the input list and creates a new
`SystemMessage` object if the role is "system" or a new `HumanMessage` object if the role is
"user". The function then returns a list of these converted messages.
"""
converted_messages = []
for message in messages:
if message["role"] == "system":
converted_messages.append(
SystemMessage(content=message["content"]))
elif message["role"] == "user":
converted_messages.append(
HumanMessage(content=message["content"]))

return converted_messages

async def get_chat_response(self, messages, stream, websocket=None):
if not stream:
# Getting output from the model chain using ainvoke for asynchronous invoking
converted_messages = self.convert_messages(messages)
output = await self.llm.ainvoke(converted_messages)

return output.content

else:
return await self.stream_response(messages, websocket)

async def stream_response(self, messages, websocket=None):
paragraph = ""
response = ""

# Streaming the response using the chain astream method from langchain
async for chunk in self.llm.astream(messages):
content = chunk.content
if content is not None:
response += content
paragraph += content
if "\n" in paragraph:
if websocket is not None:
await websocket.send_json({"type": "report", "output": paragraph})
else:
print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
paragraph = ""

return response
Empty file.
72 changes: 72 additions & 0 deletions gpt_researcher/llm_provider/openai/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os

from colorama import Fore, Style
from langchain_openai import ChatOpenAI


class OpenAIProvider:

def __init__(
self,
model,
temperature,
max_tokens
):
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.api_key = self.get_api_key()
self.llm = self.get_llm_model()

def get_api_key(self):
"""
Gets the OpenAI API key
Returns:
"""
try:
api_key = os.environ["OPENAI_API_KEY"]
except:
raise Exception(
"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
return api_key

def get_llm_model(self):
# Initializing the chat model
llm = ChatOpenAI(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
api_key=self.api_key
)

return llm

async def get_chat_response(self, messages, stream, websocket=None):
if not stream:
# Getting output from the model chain using ainvoke for asynchronous invoking
output = await self.llm.ainvoke(messages)

return output.content

else:
return await self.stream_response(messages, websocket)

async def stream_response(self, messages, websocket=None):
paragraph = ""
response = ""

# Streaming the response using the chain astream method from langchain
async for chunk in self.llm.astream(messages):
content = chunk.content
if content is not None:
response += content
paragraph += content
if "\n" in paragraph:
if websocket is not None:
await websocket.send_json({"type": "report", "output": paragraph})
else:
print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
paragraph = ""

return response
91 changes: 33 additions & 58 deletions gpt_researcher/utils/llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
# libraries
from __future__ import annotations
import logging

import json
from fastapi import WebSocket
from colorama import Fore, Style
from typing import Optional
from langchain_openai import ChatOpenAI

from colorama import Fore, Style
from fastapi import WebSocket

from gpt_researcher.master.prompts import auto_agent_instructions


def get_provider(llm_provider):
match llm_provider:
case "openai":
from ..llm_provider import OpenAIProvider
llm_provider = OpenAIProvider
case "google":
from ..llm_provider import GoogleProvider
llm_provider = GoogleProvider

case _:
raise Exception("LLM provider not found.")

return llm_provider


async def create_chat_completion(
messages: list, # type: ignore
model: Optional[str] = None,
Expand All @@ -34,71 +52,28 @@ async def create_chat_completion(
if model is None:
raise ValueError("Model cannot be None")
if max_tokens is not None and max_tokens > 8001:
raise ValueError(f"Max tokens cannot be more than 8001, but got {max_tokens}")
raise ValueError(
f"Max tokens cannot be more than 8001, but got {max_tokens}")

# Get the provider from supported providers
ProviderClass = get_provider(llm_provider)
provider = ProviderClass(
model,
temperature,
max_tokens
)

# create response
for _ in range(10): # maximum of 10 attempts
response = await send_chat_completion_request(
messages, model, temperature, max_tokens, stream, llm_provider, websocket
response = await provider.get_chat_response(
messages, stream, websocket
)
return response

logging.error("Failed to get response from OpenAI API")
raise RuntimeError("Failed to get response from OpenAI API")


import logging


async def send_chat_completion_request(
messages, model, temperature, max_tokens, stream, llm_provider, websocket=None
):
if not stream:
# Initializing the chat model
chat = ChatOpenAI(
model=model,
temperature=temperature,
max_tokens=max_tokens
)

# Getting output from the model chain using ainvoke for asynchronous invoking
output = await chat.ainvoke(messages)

return output.content

else:
return await stream_response(
model, messages, temperature, max_tokens, llm_provider, websocket
)


async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket=None):
# Initializing the model
chat = ChatOpenAI(
model=model,
temperature=temperature,
max_tokens=max_tokens
)

paragraph = ""
response = ""

# Streaming the response using the chain astream method from langchain
async for chunk in chat.astream(messages):
content = chunk.content
if content is not None:
response += content
paragraph += content
if "\n" in paragraph:
if websocket is not None:
await websocket.send_json({"type": "report", "output": paragraph})
else:
print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
paragraph = ""

return response


def choose_agent(smart_llm_model: str, llm_provider: str, task: str) -> dict:
"""Determines what server should be used
Args:
Expand Down
Loading

0 comments on commit 9587418

Please sign in to comment.