diff --git a/scripts/generative.py b/scripts/generative.py index 74d1395..60d1898 100644 --- a/scripts/generative.py +++ b/scripts/generative.py @@ -29,6 +29,7 @@ from google.generativeai.types import HarmBlockThreshold, HarmCategory from openai import OpenAI from together import Together +from cohere import Client as CohereClient ANTHROPIC_MODEL_LIST = ( "claude-1", @@ -432,7 +433,7 @@ def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=Non if api_dict is not None and "api_key" in api_dict: api_key = api_dict["api_key"] else: - api_key = os.environ["ANTHROPIC_API_KEY"] + api_key = _get_api_key("ANTHROPIC_API_KEY") sys_msg = "" if conv.messages[0]["role"] == "system": @@ -460,7 +461,8 @@ def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=Non def chat_completion_gemini(model, conv, temperature, max_tokens, api_dict=None): - genai.configure(api_key=os.environ["GEMINI_API_KEY"]) + api_key = _get_api_key("GEMINI_API_KEY") + genai.configure(api_key=api_key) api_model = genai.GenerativeModel(model) for _ in range(API_MAX_RETRY): @@ -515,7 +517,8 @@ def chat_completion_gemini(model, conv, temperature, max_tokens, api_dict=None): def chat_completion_together(model, conv, temperature, max_tokens, api_dict=None): - client = Together(api_key=os.environ["TOGETHER_API_KEY"]) + api_key = _get_api_key("TOGETHER_API_KEY") + client = Together(api_key=api_key) output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: @@ -537,7 +540,8 @@ def chat_completion_together(model, conv, temperature, max_tokens, api_dict=None def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): - client = OpenAI() + api_key = _get_api_key("OPENAI_API_KEY") + client = OpenAI(api_key=api_key) output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: @@ -567,3 +571,28 @@ def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): time.sleep(API_RETRY_SLEEP) return output + + +def chat_completion_cohere(model, conv, temperature, max_tokens, api_dict=None): + api_key = _get_api_key("CO_API_KEY") + co = CohereClient(api_key=api_key) + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + # TODO + output = response.choices[0].message.content + break + # except any exception + except Exception as e: + print(f"Failed to connect to Together API: {e}") + time.sleep(API_RETRY_SLEEP) + return output + # TODO + + +def _get_api_key(key_name: str) -> Optional[str]: + api_key = os.getenv(key_name) + if not api_key: + raise ValueError(f"{key_name} not found!") + else: + return api_key