Merge pull request #546 from baskaryan/bagatur/generic_llm
RFC: generic llm provider
Showing 7 changed files with 175 additions and 19 deletions.
@@ -0,0 +1,3 @@
from .base import GenericLLMProvider

__all__ = ["GenericLLMProvider"]
@@ -0,0 +1,147 @@
import importlib.util
from typing import Any

from colorama import Fore, Style


class GenericLLMProvider:

    def __init__(self, llm):
        self.llm = llm

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(**kwargs)
        elif provider == "azure_openai":
            _check_pkg("langchain_openai")
            from langchain_openai import AzureChatOpenAI

            llm = AzureChatOpenAI(**kwargs)
        elif provider == "cohere":
            _check_pkg("langchain_cohere")
            from langchain_cohere import ChatCohere

            llm = ChatCohere(**kwargs)
        elif provider == "google_vertexai":
            _check_pkg("langchain_google_vertexai")
            from langchain_google_vertexai import ChatVertexAI

            llm = ChatVertexAI(**kwargs)
        elif provider == "google_genai":
            _check_pkg("langchain_google_genai")
            from langchain_google_genai import ChatGoogleGenerativeAI

            llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "fireworks":
            _check_pkg("langchain_fireworks")
            from langchain_fireworks import ChatFireworks

            llm = ChatFireworks(**kwargs)
        elif provider == "ollama":
            _check_pkg("langchain_community")
            from langchain_community.chat_models import ChatOllama

            llm = ChatOllama(**kwargs)
        elif provider == "together":
            _check_pkg("langchain_together")
            from langchain_together import ChatTogether

            llm = ChatTogether(**kwargs)
        elif provider == "mistralai":
            _check_pkg("langchain_mistralai")
            from langchain_mistralai import ChatMistralAI

            llm = ChatMistralAI(**kwargs)
        elif provider == "huggingface":
            _check_pkg("langchain_huggingface")
            from langchain_huggingface import ChatHuggingFace

            # Translate the generic `model`/`model_name` kwargs into the
            # `model_id` kwarg this integration expects.
            if "model" in kwargs or "model_name" in kwargs:
                model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatHuggingFace(**kwargs)
        elif provider == "groq":
            _check_pkg("langchain_groq")
            from langchain_groq import ChatGroq

            llm = ChatGroq(**kwargs)
        elif provider == "bedrock":
            _check_pkg("langchain_aws")
            from langchain_aws import ChatBedrock

            # Bedrock likewise takes `model_id` rather than `model`/`model_name`.
            if "model" in kwargs or "model_name" in kwargs:
                model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatBedrock(**kwargs)
        else:
            supported = ", ".join(_SUPPORTED_PROVIDERS)
            raise ValueError(
                f"Unsupported {provider=}.\n\nSupported model providers are: "
                f"{supported}"
            )
        return cls(llm)

    async def get_chat_response(self, messages, stream, websocket=None):
        if not stream:
            # Getting output from the model chain using ainvoke for asynchronous invoking
            output = await self.llm.ainvoke(messages)

            return output.content

        else:
            return await self.stream_response(messages, websocket)

    async def stream_response(self, messages, websocket=None):
        paragraph = ""
        response = ""

        # Streaming the response using the chain astream method from langchain
        async for chunk in self.llm.astream(messages):
            content = chunk.content
            if content is not None:
                response += content
                paragraph += content
                if "\n" in paragraph:
                    # Flush a completed paragraph to the websocket if one is
                    # attached, otherwise print it to the console.
                    if websocket is not None:
                        await websocket.send_json({"type": "report", "output": paragraph})
                    else:
                        print(f"{Fore.GREEN}{paragraph}{Style.RESET_ALL}")
                    paragraph = ""

        return response


_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
}


def _check_pkg(pkg: str) -> None:
    # Fail early with an actionable message if the provider's integration
    # package is not installed.
    if not importlib.util.find_spec(pkg):
        pkg_kebab = pkg.replace("_", "-")
        raise ImportError(
            f"Unable to import {pkg_kebab}. Please install with "
            f"`pip install -U {pkg_kebab}`"
        )