forked from frdel/agent-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
81 lines (58 loc) · 4.77 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings, AzureChatOpenAI, AzureOpenAIEmbeddings, AzureOpenAI
from langchain_community.llms.ollama import Ollama
from langchain_ollama import ChatOllama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
from pydantic.v1.types import SecretStr
# Configuration: sampling temperature used by every factory in this module
# whenever the caller does not pass one explicitly.
DEFAULT_TEMPERATURE = 0.0

# Load environment variables from a .env file (if present) into os.environ so
# the os.getenv() lookups in this module can see them.
load_dotenv()
# Utility function to get API keys from environment variables
def get_api_key(service):
    """Return the API key configured for *service* in the environment.

    The service name is upper-cased and two variable spellings are tried in
    order: ``API_KEY_<SERVICE>`` first, then ``<SERVICE>_API_KEY``. An empty
    value for the first spelling is treated as missing.

    Returns:
        The first spelling's value if it is non-empty; otherwise whatever the
        second lookup yields (``None`` if unset, ``""`` if set but empty).
    """
    upper_name = service.upper()
    preferred = os.getenv(f"API_KEY_{upper_name}")
    if preferred:
        return preferred
    # Fallback spelling, returned as-is (may be None or "").
    return os.getenv(f"{upper_name}_API_KEY")
# Ollama models
def get_ollama_chat(model_name: str, temperature=DEFAULT_TEMPERATURE, base_url=None, num_ctx=8192):
    """Create a ``ChatOllama`` chat model.

    Args:
        model_name: Ollama model name, passed as ``model``.
        temperature: Sampling temperature passed through to ``ChatOllama``.
        base_url: Ollama server URL. When ``None`` (the default) it is read at
            call time from ``OLLAMA_BASE_URL``, falling back to
            ``"http://127.0.0.1:11434"``.
        num_ctx: Passed through to ``ChatOllama`` as ``num_ctx``.

    Returns:
        A configured ``ChatOllama`` instance.
    """
    # Resolved here instead of in the signature: a default expression is
    # evaluated only once (at import), so later env changes were ignored.
    if base_url is None:
        base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
    return ChatOllama(model=model_name, temperature=temperature, base_url=base_url, num_ctx=num_ctx)
def get_ollama_embedding(model_name: str, temperature=DEFAULT_TEMPERATURE, base_url=None):
    """Create an ``OllamaEmbeddings`` embedding model.

    Args:
        model_name: Ollama model name, passed as ``model``.
        temperature: Passed through to ``OllamaEmbeddings``.
        base_url: Ollama server URL. When ``None`` (the default) it is read at
            call time from ``OLLAMA_BASE_URL``, falling back to
            ``"http://127.0.0.1:11434"``.

    Returns:
        A configured ``OllamaEmbeddings`` instance.
    """
    # Resolved at call time; a signature default would be frozen at import.
    if base_url is None:
        base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
    return OllamaEmbeddings(model=model_name, temperature=temperature, base_url=base_url)
# HuggingFace models
def get_huggingface_embedding(model_name: str):
    """Build a ``HuggingFaceEmbeddings`` model for *model_name*.

    Args:
        model_name: Passed through unchanged as ``model_name``.

    Returns:
        The constructed ``HuggingFaceEmbeddings`` instance.
    """
    embedder = HuggingFaceEmbeddings(model_name=model_name)
    return embedder
# LM Studio and other OpenAI compatible interfaces
def get_lmstudio_chat(model_name: str, temperature=DEFAULT_TEMPERATURE, base_url=None):
    """Create a ``ChatOpenAI`` client pointed at an OpenAI-compatible local server.

    Args:
        model_name: Model identifier sent to the server.
        temperature: Sampling temperature passed through to ``ChatOpenAI``.
        base_url: Server URL. When ``None`` (the default) it is read at call
            time from ``LM_STUDIO_BASE_URL``, falling back to
            ``"http://127.0.0.1:1234/v1"``.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    # Resolved at call time; a signature default would be frozen at import.
    if base_url is None:
        base_url = os.getenv("LM_STUDIO_BASE_URL") or "http://127.0.0.1:1234/v1"
    # "none" is a placeholder key; presumably the local server does not check it.
    return ChatOpenAI(model_name=model_name, base_url=base_url, temperature=temperature, api_key="none")  # type: ignore
def get_lmstudio_embedding(model_name: str, base_url=None):
    """Create an ``OpenAIEmbeddings`` client pointed at an OpenAI-compatible local server.

    Args:
        model_name: Embedding model identifier sent to the server.
        base_url: Server URL. When ``None`` (the default) it is read at call
            time from ``LM_STUDIO_BASE_URL``, falling back to
            ``"http://127.0.0.1:1234/v1"``.

    Returns:
        A configured ``OpenAIEmbeddings`` instance.
    """
    # Resolved at call time; a signature default would be frozen at import.
    if base_url is None:
        base_url = os.getenv("LM_STUDIO_BASE_URL") or "http://127.0.0.1:1234/v1"
    # check_embedding_ctx_length=False skips OpenAIEmbeddings' token-length
    # check/splitting; presumably because it targets OpenAI tokenizers - confirm.
    return OpenAIEmbeddings(model=model_name, api_key="none", base_url=base_url, check_embedding_ctx_length=False)  # type: ignore
# Anthropic models
def get_anthropic_chat(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE):
    """Create a ``ChatAnthropic`` chat model.

    Args:
        model_name: Anthropic model name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("anthropic")``.
        temperature: Sampling temperature passed through to ``ChatAnthropic``.

    Returns:
        A configured ``ChatAnthropic`` instance.
    """
    # Looked up at call time: as a signature default the key was frozen at
    # import and exposed through help()/inspect.signature().
    if api_key is None:
        api_key = get_api_key("anthropic")
    return ChatAnthropic(model_name=model_name, temperature=temperature, api_key=api_key)  # type: ignore
# OpenAI models
def get_openai_chat(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE):
    """Create a ``ChatOpenAI`` chat model.

    Args:
        model_name: OpenAI model name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai")``.
        temperature: Sampling temperature passed through to ``ChatOpenAI``.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    # Looked up at call time: as a signature default the key was frozen at
    # import and exposed through help()/inspect.signature().
    if api_key is None:
        api_key = get_api_key("openai")
    return ChatOpenAI(model_name=model_name, temperature=temperature, api_key=api_key)  # type: ignore
def get_openai_instruct(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE):
    """Create an ``OpenAI`` (completion-style LLM) model.

    Args:
        model_name: OpenAI model name, passed as ``model``.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai")``.
        temperature: Sampling temperature passed through to ``OpenAI``.

    Returns:
        A configured ``OpenAI`` instance.
    """
    # Looked up at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("openai")
    return OpenAI(model=model_name, temperature=temperature, api_key=api_key)  # type: ignore
def get_openai_embedding(model_name: str, api_key=None):
    """Create an ``OpenAIEmbeddings`` embedding model.

    Args:
        model_name: OpenAI embedding model name, passed as ``model``.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai")``.

    Returns:
        A configured ``OpenAIEmbeddings`` instance.
    """
    # Looked up at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("openai")
    return OpenAIEmbeddings(model=model_name, api_key=api_key)  # type: ignore
def get_azure_openai_chat(deployment_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE, azure_endpoint=None):
    """Create an ``AzureChatOpenAI`` chat model for an Azure OpenAI deployment.

    Args:
        deployment_name: Azure deployment name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai_azure")``.
        temperature: Sampling temperature passed through to ``AzureChatOpenAI``.
        azure_endpoint: Endpoint URL. When ``None`` (the default) it is read at
            call time from ``OPENAI_AZURE_ENDPOINT`` (may still be ``None``).

    Returns:
        A configured ``AzureChatOpenAI`` instance.
    """
    # Resolved at call time; signature defaults were frozen at import and
    # exposed the key through help()/inspect.signature().
    if api_key is None:
        api_key = get_api_key("openai_azure")
    if azure_endpoint is None:
        azure_endpoint = os.getenv("OPENAI_AZURE_ENDPOINT")
    return AzureChatOpenAI(deployment_name=deployment_name, temperature=temperature, api_key=api_key, azure_endpoint=azure_endpoint)  # type: ignore
def get_azure_openai_instruct(deployment_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE, azure_endpoint=None):
    """Create an ``AzureOpenAI`` (completion-style LLM) model for an Azure deployment.

    Args:
        deployment_name: Azure deployment name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai_azure")``.
        temperature: Sampling temperature passed through to ``AzureOpenAI``.
        azure_endpoint: Endpoint URL. When ``None`` (the default) it is read at
            call time from ``OPENAI_AZURE_ENDPOINT`` (may still be ``None``).

    Returns:
        A configured ``AzureOpenAI`` instance.
    """
    # Resolved at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("openai_azure")
    if azure_endpoint is None:
        azure_endpoint = os.getenv("OPENAI_AZURE_ENDPOINT")
    return AzureOpenAI(deployment_name=deployment_name, temperature=temperature, api_key=api_key, azure_endpoint=azure_endpoint)  # type: ignore
def get_azure_openai_embedding(deployment_name: str, api_key=None, azure_endpoint=None):
    """Create an ``AzureOpenAIEmbeddings`` model for an Azure OpenAI deployment.

    Args:
        deployment_name: Azure deployment name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openai_azure")``.
        azure_endpoint: Endpoint URL. When ``None`` (the default) it is read at
            call time from ``OPENAI_AZURE_ENDPOINT`` (may still be ``None``).

    Returns:
        A configured ``AzureOpenAIEmbeddings`` instance.
    """
    # Resolved at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("openai_azure")
    if azure_endpoint is None:
        azure_endpoint = os.getenv("OPENAI_AZURE_ENDPOINT")
    return AzureOpenAIEmbeddings(deployment_name=deployment_name, api_key=api_key, azure_endpoint=azure_endpoint)  # type: ignore
# Google models
def get_google_chat(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE):
    """Create a Google Generative AI model.

    NOTE(review): despite the ``_chat`` name this returns ``GoogleGenerativeAI``
    (LangChain's LLM class), not the chat-model class ``ChatGoogleGenerativeAI``
    - confirm callers expect that.

    Args:
        model_name: Gemini model name, passed as ``model``.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("google")``.
        temperature: Sampling temperature passed through.

    Returns:
        A configured ``GoogleGenerativeAI`` instance.
    """
    # Looked up at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("google")
    # Only the dangerous-content category is set to BLOCK_NONE here; other
    # harm categories are not overridden.
    return GoogleGenerativeAI(model=model_name, temperature=temperature, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE})  # type: ignore
# Groq models
def get_groq_chat(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE):
    """Create a ``ChatGroq`` chat model.

    Args:
        model_name: Groq model name.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("groq")``.
        temperature: Sampling temperature passed through to ``ChatGroq``.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    # Looked up at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("groq")
    return ChatGroq(model_name=model_name, temperature=temperature, api_key=api_key)  # type: ignore
# OpenRouter models
def get_openrouter_chat(model_name: str, api_key=None, temperature=DEFAULT_TEMPERATURE, base_url=None):
    """Create a ``ChatOpenAI`` client pointed at OpenRouter's OpenAI-compatible API.

    Args:
        model_name: OpenRouter model identifier, passed as ``model``.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openrouter")``.
        temperature: Sampling temperature passed through to ``ChatOpenAI``.
        base_url: API base URL. When ``None`` (the default) it is read at call
            time from ``OPEN_ROUTER_BASE_URL``, falling back to
            ``"https://openrouter.ai/api/v1"``.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    # Resolved at call time; signature defaults were frozen at import and
    # exposed the key through help()/inspect.signature().
    if api_key is None:
        api_key = get_api_key("openrouter")
    if base_url is None:
        base_url = os.getenv("OPEN_ROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
    return ChatOpenAI(api_key=api_key, model=model_name, temperature=temperature, base_url=base_url)  # type: ignore
def get_openrouter_embedding(model_name: str, api_key=None, base_url=None):
    """Create an ``OpenAIEmbeddings`` client pointed at OpenRouter's base URL.

    Args:
        model_name: Embedding model identifier, passed as ``model``.
        api_key: API key. When ``None`` (the default) it is looked up at call
            time via ``get_api_key("openrouter")``.
        base_url: API base URL. When ``None`` (the default) it is read at call
            time from ``OPEN_ROUTER_BASE_URL``, falling back to
            ``"https://openrouter.ai/api/v1"``.

    Returns:
        A configured ``OpenAIEmbeddings`` instance.
    """
    # Resolved at call time instead of being frozen into the signature.
    if api_key is None:
        api_key = get_api_key("openrouter")
    if base_url is None:
        base_url = os.getenv("OPEN_ROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
    return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url)  # type: ignore