Skip to content

Commit

Permalink
Enhance chatbot routing with 'provider: model_name' format.
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-plus committed May 17, 2024
1 parent 97253ad commit b345da5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 20 deletions.
41 changes: 36 additions & 5 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import os
import random
import re
import time
from typing import List, Union, Dict, Callable

Expand All @@ -21,15 +22,46 @@
from openlrc.utils import get_messages_token_number, get_text_token_number

# Maps a model name to the chatbot class registered for it; populated by
# _register_chatbot() from each registered class's `pricing` table.
model2chatbot = {}
# Maps a model name to its (prompt_price, completion_price) pair. Fee code
# divides token_count * price by 1,000,000, so prices are per million tokens
# (currency is presumably USD -- TODO confirm against the linked pricing pages).
# _register_chatbot() merges every registered class's `pricing` into this dict.
# NOTE(review): the entries listed here do not appear to be added to
# model2chatbot, so they look reachable only through the
# 'provider: model_name' form -- confirm against the chatbot classes.
all_pricing = {
# Third-party provider models from https://api.g4f.icu/pricing
'mixtral-8x7b-32768': (0.25, 1),
'llama2-70b-4096': (0.25, 1.25),

# https://platform.deepseek.com/api-docs/pricing/
'deepseek-chat': (0.14, 0.28)
}


def _register_chatbot(cls):
    """Class decorator that records a chatbot class in the module registries.

    The class's ``pricing`` table is merged into ``all_pricing`` and every
    model name it lists is mapped to ``cls`` in ``model2chatbot``.

    Args:
        cls: Chatbot class exposing a ``pricing`` mapping keyed by model name.

    Returns:
        The same class, unchanged, so the function can be used as a decorator.
    """
    price_table = cls.pricing
    all_pricing.update(price_table)
    # Point every model named in the price table at this class in one step.
    model2chatbot.update(dict.fromkeys(price_table, cls))
    return cls


def route_chatbot(model):
    """Resolve a model specification to a chatbot class and a bare model name.

    Two forms are accepted:

    * ``'provider: model_name'`` -- the provider (``openai`` or ``anthropic``,
      case-insensitive) explicitly selects the chatbot class, regardless of
      which class the model was registered with. The model name must still be
      present in ``all_pricing``.
    * ``'model_name'`` -- the class is looked up in ``model2chatbot``.

    Args:
        model: Model specification string in one of the forms above.

    Returns:
        A ``(chatbot_class, model_name)`` tuple.

    Raises:
        ValueError: If the specification is malformed (empty provider or model
            name), the model is unknown, or the provider is not supported.
    """
    if ':' in model:
        # Split on the FIRST colon so the provider is always the leading token.
        # str.partition never fails, unlike re.match(...).groups(), which
        # raised AttributeError for inputs such as 'openai:' or ':gpt-4'.
        provider, _, model_name = model.partition(':')
        provider, model_name = provider.strip().lower(), model_name.strip()

        if not provider or not model_name:
            raise ValueError(
                f'Invalid model {model}. Expected the format "provider: model_name".'
            )

        if model_name not in all_pricing:
            raise ValueError(f'Invalid model {model_name}.')

        if provider == 'openai':
            return GPTBot, model_name
        if provider == 'anthropic':
            return ClaudeBot, model_name
        raise ValueError(f'Invalid chatbot type {provider}.')

    if model not in model2chatbot:
        raise ValueError(f'Invalid model {model}.')

    return model2chatbot[model], model


class ChatBot:
pricing = None

Expand All @@ -51,7 +83,7 @@ def model(self):

@model.setter
def model(self, model):
if model not in self.pricing:
if model not in all_pricing:
raise ValueError(f'Invalid model {model}.')
self._model = model

Expand All @@ -63,7 +95,7 @@ def estimate_fee(self, messages: List[Dict]):
for message in messages:
token_map[message['role']] += get_text_token_number(message['content'])

prompt_price, completion_price = self.pricing[self.model]
prompt_price, completion_price = all_pricing[self.model]

total_price = (sum(token_map.values()) * prompt_price + token_map['user'] * completion_price * 2) / 1000000

Expand Down Expand Up @@ -135,7 +167,6 @@ class GPTBot(ChatBot):
'gpt-4-turbo': (10, 30),
'gpt-4-turbo-2024-04-09': (10, 30),
'gpt-4o': (5, 15),
'deepseek-chat': (0.14, 0.28)
}

def __init__(self, model='gpt-3.5-turbo-0125', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
Expand All @@ -161,7 +192,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.async_client.close()

def update_fee(self, response: ChatCompletion):
prompt_price, completion_price = self.pricing[self.model]
prompt_price, completion_price = all_pricing[self.model]

prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
Expand Down Expand Up @@ -240,7 +271,7 @@ def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, ret
self.fee_limit = fee_limit

def update_fee(self, response: Message):
prompt_price, completion_price = self.pricing[self.model]
prompt_price, completion_price = all_pricing[self.model]

prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
Expand Down
27 changes: 13 additions & 14 deletions openlrc/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import requests

from openlrc.chatbot import model2chatbot
from openlrc.chatbot import route_chatbot, all_pricing
from openlrc.logger import logger
from openlrc.prompter import prompter_map, BaseTranslatePrompter, AtomicTranslatePrompter

Expand Down Expand Up @@ -43,20 +43,19 @@ def __init__(self, chatbot_model: str = 'gpt-3.5-turbo', prompter: str = 'base_t
if prompter not in prompter_map:
raise ValueError(f'Prompter {prompter} not found.')

if chatbot_model not in model2chatbot.keys():
raise ValueError(f'Chatbot {chatbot_model} not supported.')

self.temperature = 0.9

chatbot_category = model2chatbot[chatbot_model]
self.chatbot = chatbot_category(model=chatbot_model, fee_limit=fee_limit, proxy=proxy, retry=3,
temperature=self.temperature, base_url_config=base_url_config)
self.retry_chatbot = model2chatbot[retry_model](
model=retry_model, fee_limit=fee_limit,
proxy=proxy, retry=3,
temperature=self.temperature,
base_url_config=base_url_config
) if retry_model else None
chatbot_cls, model_name = route_chatbot(chatbot_model)
self.chatbot = chatbot_cls(model=model_name, fee_limit=fee_limit, proxy=proxy, retry=3,
temperature=self.temperature, base_url_config=base_url_config)

self.retry_chatbot = None
if retry_model:
retry_chatbot_cls, retry_model_name = route_chatbot(retry_model)
self.retry_chatbot = retry_chatbot_cls[retry_model](
model=retry_model_name, fee_limit=fee_limit, proxy=proxy, retry=3, temperature=self.temperature,
base_url_config=base_url_config
)

self.prompter = prompter
self.fee_limit = fee_limit
Expand All @@ -73,7 +72,7 @@ def list_chatbots():
Returns:
List[str]: List of available chatbot models.
"""
return list(model2chatbot.keys())
return list(all_pricing.keys())

@staticmethod
def make_chunks(texts, chunk_size=30):
Expand Down
24 changes: 23 additions & 1 deletion tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel

from openlrc.chatbot import GPTBot, ClaudeBot
from openlrc.chatbot import GPTBot, ClaudeBot, route_chatbot


class Usage(BaseModel):
Expand Down Expand Up @@ -123,3 +123,25 @@ def test_claude_message_seq(self):
assert 'hello' in bot.get_content(results[0]).lower()

self.assertIn('hello', bot.get_content(results[0]).lower())

def test_route_chatbot(self):
    """An explicit provider prefix selects the chatbot class.

    Each case deliberately pairs a provider with a model that belongs to the
    other provider's price table, to check that the 'provider: model_name'
    form routes by the prefix and that the resulting class accepts the model.
    """
    cases = (
        ('openai: claude-3-haiku-20240307', GPTBot),
        ('anthropic: gpt-3.5-turbo', ClaudeBot),
    )
    for chatbot_model, expected_cls in cases:
        chatbot_cls, model_name = route_chatbot(chatbot_model)
        self.assertEqual(chatbot_cls, expected_cls)
        try:
            _ = chatbot_cls(model=model_name, temperature=1, top_p=1, retry=8, max_async=16)
        except Exception as e:
            # Report the specification under test; the original second case
            # reported the first case's specification by mistake.
            self.fail(f"Failed to create chatbot model {chatbot_model}: {e}")

def test_route_chatbot_error(self):
    """Routing a model name that is not in the price table raises ValueError."""
    unknown_spec = 'openai: invalid_model_name' + 'error'
    self.assertRaises(ValueError, route_chatbot, unknown_spec)

0 comments on commit b345da5

Please sign in to comment.