Skip to content
This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Commit

Permalink
fix parameter position
Browse files Browse the repository at this point in the history
  • Loading branch information
gharbat committed Dec 12, 2023
1 parent bea0c85 commit 3d9cbe1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
openai_api_key = os.getenv("OPENAI_API_KEY")
logger = CustomLogger(module_name=__name__)

chat = get_chat_model(CHAT_MODELS.gpt_3_5_turbo_16k)
chat = get_chat_model()


def convert_json_to_text(
Expand Down
18 changes: 10 additions & 8 deletions llm-server/utils/get_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from typing import Dict
import os
from functools import lru_cache

from langchain.chat_models import ChatOllama, ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models import ChatOllama, ChatAnthropic

from utils import llm_consts
from .chat_models import CHAT_MODELS
from functools import lru_cache
import os

localip = os.getenv("LOCAL_IP", "localhost")

model_name = os.getenv(llm_consts.model_env_var, CHAT_MODELS.gpt_4_32k)


@lru_cache(maxsize=1)
def get_chat_model() -> BaseChatModel:
if model_name == CHAT_MODELS.gpt_3_5_turbo:
model = ChatOpenAI(
model=CHAT_MODELS.gpt_3_5_turbo,
model=CHAT_MODELS.gpt_3_5_turbo,
temperature=0
)
elif model_name == CHAT_MODELS.gpt_4_32k:
Expand All @@ -30,12 +32,12 @@ def get_chat_model() -> BaseChatModel:
)
elif model_name == CHAT_MODELS.gpt_3_5_turbo_16k:
model = ChatOpenAI(
model=CHAT_MODELS.gpt_3_5_turbo_16k,
model=CHAT_MODELS.gpt_3_5_turbo_16k,
temperature=0
)
elif model_name == "claude":
model = ChatAnthropic(
anthropic_api_key=os.getenv("CLAUDE_API_KEY"),
anthropic_api_key=os.getenv("CLAUDE_API_KEY"),
)
elif model_name == "openchat":
model = ChatOllama(
Expand All @@ -45,4 +47,4 @@ def get_chat_model() -> BaseChatModel:
)
else:
raise ValueError(f"Unsupported model: {model_name}")
return model
return model

0 comments on commit 3d9cbe1

Please sign in to comment.