Skip to content

Commit

Permalink
[wip] Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Jul 18, 2024
1 parent 2e270dc commit 8d72202
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions scripts/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from openai import OpenAI
from together import Together
from cohere import Client as CohereClient

ANTHROPIC_MODEL_LIST = (
"claude-1",
Expand Down Expand Up @@ -432,7 +433,7 @@ def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=Non
if api_dict is not None and "api_key" in api_dict:
api_key = api_dict["api_key"]
else:
api_key = os.environ["ANTHROPIC_API_KEY"]
api_key = _get_api_key("ANTHROPIC_API_KEY")

sys_msg = ""
if conv.messages[0]["role"] == "system":
Expand Down Expand Up @@ -460,7 +461,8 @@ def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=Non


def chat_completion_gemini(model, conv, temperature, max_tokens, api_dict=None):
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
api_key = _get_api_key("GEMINI_API_KEY")
genai.configure(api_key=api_key)
api_model = genai.GenerativeModel(model)

for _ in range(API_MAX_RETRY):
Expand Down Expand Up @@ -515,7 +517,8 @@ def chat_completion_gemini(model, conv, temperature, max_tokens, api_dict=None):


def chat_completion_together(model, conv, temperature, max_tokens, api_dict=None):
client = Together(api_key=os.environ["TOGETHER_API_KEY"])
api_key = _get_api_key("TOGETHER_API_KEY")
client = Together(api_key=api_key)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
Expand All @@ -537,7 +540,8 @@ def chat_completion_together(model, conv, temperature, max_tokens, api_dict=None


def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None):
client = OpenAI()
api_key = _get_api_key("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
Expand Down Expand Up @@ -567,3 +571,28 @@ def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None):
time.sleep(API_RETRY_SLEEP)

return output


def chat_completion_cohere(model, conv, temperature, max_tokens, api_dict=None):
api_key = _get_api_key("CO_API_KEY")
co = CohereClient(api_key=api_key)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
# TODO
output = response.choices[0].message.content
break
# except any exception
except Exception as e:
print(f"Failed to connect to Together API: {e}")
time.sleep(API_RETRY_SLEEP)
return output
# TODO


def _get_api_key(key_name: str) -> Optional[str]:
api_key = os.getenv(key_name)
if not api_key:
raise ValueError(f"{key_name} not found!")
else:
return api_key

0 comments on commit 8d72202

Please sign in to comment.