From a60f7420bc4350a6b0f8454e13c18fb1fa73c498 Mon Sep 17 00:00:00 2001 From: Lj Miranda Date: Thu, 18 Jul 2024 10:09:44 -0700 Subject: [PATCH] [wip] Update --- scripts/generative.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/generative.py b/scripts/generative.py index 7f242e0..d226d66 100644 --- a/scripts/generative.py +++ b/scripts/generative.py @@ -432,7 +432,10 @@ def run_judge_pair( conv.set_system_message(system_prompt) judgment = chat_completion_together(model, conv, temperature=0, max_tokens=2048) elif model in COHERE_MODEL_LIST: - pass + conv = get_conv_template("raw") + conv.append_message(conv.roles[0], user_prompt) + conv.set_system_message(system_prompt) + judgment = chat_completion_cohere(model, conv, temperature=0, max_tokens=2048) else: raise ValueError(f"Model {model} not supported") @@ -593,23 +596,20 @@ def chat_completion_cohere(model, conv, temperature, max_tokens, api_dict=None): output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: - # TODO: https://docs.cohere.com/reference/chat - co.chat( - message="", # TODO: message? - preamble="", # TODO: system prompt + response = co.chat( + message=conv.messages[0][1], + preamble=conv.system_message, model=model, - chat_history="", # TODO: chat history? temperature=temperature, max_tokens=max_tokens, ) - output = response.choices[0].message.content + output = response.get("text") break # except any exception except Exception as e: - print(f"Failed to connect to Together API: {e}") + print(f"Failed to connect to Cohere API: {e}") time.sleep(API_RETRY_SLEEP) return output - # TODO def _get_api_key(key_name: str) -> Optional[str]: