Enable OpenAI model streaming and fix num_tokens_from_messages (#184)
zechengz authored Jun 30, 2023
1 parent f4de3ff commit e92705f
Showing 7 changed files with 312 additions and 87 deletions.
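
For orientation before the per-file diffs: streaming is opted into through the model config dict. The new OpenAIModel.stream property reads the 'stream' key, and ChatAgent.step then consumes a chunk generator instead of a single response dict. A rough usage sketch follows; aside from BaseMessage, ModelType, RoleType, and the 'stream' config key, the names used here (ChatGPTConfig, the ChatAgent constructor arguments, the msgs attribute on the returned struct) are assumptions for illustration and are not taken from this commit.

# Hedged usage sketch: ChatGPTConfig, the ChatAgent constructor arguments, and
# response.msgs are assumed names; only the 'stream' flag itself comes from this diff.
from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.typing import ModelType, RoleType

sys_msg = BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                      meta_dict=dict(),
                      content="You are a helpful assistant.")
agent = ChatAgent(system_message=sys_msg,
                  model=ModelType.GPT_3_5_TURBO,
                  model_config=ChatGPTConfig(stream=True))  # enables the streaming path

user_msg = BaseMessage(role_name="User", role_type=RoleType.USER,
                       meta_dict=dict(), content="Explain streaming in one line.")
response = agent.step(user_msg)   # same ChatAgentResponse shape as batch mode
print(response.msgs[0].content)   # fully accumulated streamed content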
143 changes: 121 additions & 22 deletions camel/agents/chat_agent.py
@@ -11,8 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from types import GeneratorType
from typing import Any, Dict, List, Optional, Tuple

from tenacity import retry
from tenacity.stop import stop_after_attempt
@@ -23,7 +25,11 @@
from camel.messages import BaseMessage
from camel.models import BaseModelBackend, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import num_tokens_from_messages, openai_api_key_required
from camel.utils import (
get_model_encoding,
num_tokens_from_messages,
openai_api_key_required,
)


@dataclass(frozen=True)
@@ -197,8 +203,8 @@ def update_messages(self, role: str,
return self.stored_messages

def submit_message(self, message: BaseMessage) -> None:
r"""Submits the externaly provided message as if it were an answer of
the chat LLM from the backend. Currently the choise of the critic is
r"""Submits the externally provided message as if it were an answer of
the chat LLM from the backend. Currently, the choice of the critic is
submitted with this method.
Args:
@@ -223,10 +229,9 @@ def step(
for the self agent any incoming message is external.
Returns:
ChatAgentResponse: A struct
containing the output messages, a boolean indicating whether
the chat session has terminated, and information about the chat
session.
ChatAgentResponse: A struct containing the output messages,
a boolean indicating whether the chat session has terminated,
and information about the chat session.
"""
messages = self.update_messages('user', input_message)
if self.message_window_size is not None and len(
@@ -241,21 +246,17 @@

if num_tokens < self.model_token_limit:
response = self.model_backend.run(openai_messages)
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
BaseMessage(role_name=self.role_name, role_type=self.role_type,
meta_dict=dict(),
content=choice["message"]['content'])
for choice in response["choices"]
]
self.validate_model_response(response)
if not self.model_backend.stream:
output_messages, finish_reasons, usage_dict, response_id = \
self.handle_batch_response(response)
else:
output_messages, finish_reasons, usage_dict, response_id = \
self.handle_stream_response(response, num_tokens)
info = self.get_info(
response["id"],
response["usage"],
[
str(choice["finish_reason"])
for choice in response["choices"]
],
response_id,
usage_dict,
finish_reasons,
num_tokens,
)
else:
@@ -271,6 +272,104 @@

return ChatAgentResponse(output_messages, self.terminated, info)

def validate_model_response(self, response: Any):
if not self.model_backend.stream:
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected batch struct")
else:
if not isinstance(response, GeneratorType):
raise RuntimeError("OpenAI returned unexpected stream struct")

def handle_batch_response(
self, response: Dict[str, Any]
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
r"""
Args:
response (dict): Model response.
Returns:
tuple: A tuple of list of output `ChatMessage`, list of
finish reasons, usage dictionary, and response id.
"""
output_messages: List[BaseMessage] = []
for choice in response["choices"]:
chat_message = BaseMessage(role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=choice["message"]['content'])
output_messages.append(chat_message)
finish_reasons = [
str(choice["finish_reason"]) for choice in response["choices"]
]
return output_messages, finish_reasons, dict(
response["usage"]), response["id"]

def handle_stream_response(
self,
response: Any,
prompt_tokens: int,
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
r"""
Args:
response (dict): Model response.
prompt_tokens (int): Number of input prompt tokens.
Returns:
tuple: A tuple of list of output `ChatMessage`, list of
finish reasons, usage dictionary, and response id.
"""
content_dict = defaultdict(lambda: "")
finish_reasons = defaultdict(lambda: "")
output_messages: List[BaseMessage] = []
response_id: str = ""
# All choices in one response share one role
role: str = ""
for chunk in response:
response_id = chunk["id"]
for choice in chunk["choices"]:
index = choice["index"]
delta = choice["delta"]
if len(delta) != 0:
# When response has not been stopped
# Notice that only the first chunk has the "role"
role = delta.get("role", role)
delta_content = delta.get("content", "")
content_dict[index] += delta_content
else:
finish_reasons[index] = choice["finish_reason"]
chat_message = BaseMessage(role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=content_dict[index])
output_messages.append(chat_message)
finish_reasons = [
finish_reasons[i] for i in range(len(finish_reasons))
]
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
return output_messages, finish_reasons, usage_dict, response_id

def get_usage_dict(self, output_messages: List[BaseMessage],
prompt_tokens: int) -> Dict[str, int]:
r"""Get usage dictionary when using the stream mode.
Args:
output_messages (list): List of output messages.
prompt_tokens (int): Number of input prompt tokens.
Returns:
dict: Usage dictionary.
"""
encoding = get_model_encoding(self.model.value_for_tiktoken)
completion_tokens = 0
for message in output_messages:
completion_tokens += len(encoding.encode(message.content))
usage_dict = dict(completion_tokens=completion_tokens,
prompt_tokens=prompt_tokens,
total_tokens=completion_tokens + prompt_tokens)
return usage_dict

def __repr__(self) -> str:
r"""Returns a string representation of the :obj:`ChatAgent`.
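
To see what the new handle_stream_response above is doing with the generator, here is a self-contained sketch of the same accumulation pattern over hypothetical stream chunks shaped like the legacy OpenAI chat-completion deltas; the sample data is invented for illustration and runs without an API key.

from collections import defaultdict

# Made-up chunks in the shape the legacy OpenAI SDK streams: the first carries
# the role and some content, the last carries only a finish_reason.
chunks = [
    {"id": "chatcmpl-demo",
     "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]},
    {"id": "chatcmpl-demo",
     "choices": [{"index": 0, "delta": {"content": " world"}}]},
    {"id": "chatcmpl-demo",
     "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]},
]

content = defaultdict(str)   # accumulated text per choice index
finish_reasons = {}          # finish reason per choice index
for chunk in chunks:
    for choice in chunk["choices"]:
        delta = choice["delta"]
        if delta:            # still receiving partial content
            content[choice["index"]] += delta.get("content", "")
        else:                # an empty delta marks the end of that choice
            finish_reasons[choice["index"]] = choice["finish_reason"]

print(content[0], finish_reasons[0])  # -> Hello world stop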
9 changes: 9 additions & 0 deletions camel/models/base_model.py
@@ -56,3 +56,12 @@ def token_limit(self) -> int:
int: The maximum token limit for the given model.
"""
return self.model_type.token_limit

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode,
which sends partial results each time.
Returns:
bool: Whether the model is in stream mode.
"""
return False
18 changes: 16 additions & 2 deletions camel/models/openai_model.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from types import GeneratorType
from typing import Any, Dict, List

from camel.messages import OpenAIMessage
@@ -48,6 +49,19 @@ def run(self, messages: List[Dict]) -> Dict[str, Any]:
response = openai.ChatCompletion.create(messages=messages_openai,
model=self.model_type.value,
**self.model_config_dict)
if not isinstance(response, Dict):
raise RuntimeError("Unexpected return from OpenAI API")
if not self.stream:
if not isinstance(response, Dict):
raise RuntimeError("Unexpected batch return from OpenAI API")
else:
if not isinstance(response, GeneratorType):
raise RuntimeError("Unexpected stream return from OpenAI API")
return response

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode,
which sends partial results each time.
Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get('stream', False)
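
The isinstance checks in run above rely on a behavior of the pre-1.0 openai SDK: ChatCompletion.create returns a dict-like response object in batch mode but a plain Python generator of chunks when stream=True, which is why GeneratorType is the right check. A hedged sketch of that behavior (requires the legacy openai<1.0 package and an API key in OPENAI_API_KEY):

import os
from types import GeneratorType

import openai  # legacy (<1.0) SDK, matching the API used in this file

openai.api_key = os.environ["OPENAI_API_KEY"]
messages = [{"role": "user", "content": "Say hi."}]

batch = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
stream = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages,
                                      stream=True)

print(isinstance(batch, dict))            # the response object subclasses dict -> True
print(isinstance(stream, GeneratorType))  # streaming returns a generator -> True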
75 changes: 51 additions & 24 deletions camel/utils.py
@@ -27,6 +27,23 @@
F = TypeVar('F', bound=Callable[..., Any])


def get_model_encoding(value_for_tiktoken: str):
r"""Get model encoding from tiktoken.
Args:
value_for_tiktoken: Model value for tiktoken.
Returns:
tiktoken.Encoding: Model encoding.
"""
try:
encoding = tiktoken.encoding_for_model(value_for_tiktoken)
except KeyError:
print("Model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return encoding


def count_tokens_openai_chat_models(
messages: List[OpenAIMessage],
encoding: tiktoken.Encoding,
@@ -39,6 +56,10 @@ def count_tokens_openai_chat_models(
Args:
messages (List[OpenAIMessage]): The list of messages.
encoding (tiktoken.Encoding): The encoding method to use.
tokens_per_message (int): Number of tokens to be added
to each message.
tokens_per_name (int): Number of tokens to be added if
name existed in the message.
Returns:
int: The number of tokens required.
Expand All @@ -48,7 +69,7 @@ def count_tokens_openai_chat_models(
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
@@ -76,32 +97,35 @@ def num_tokens_from_messages(
- https://platform.openai.com/docs/models/gpt-4
- https://platform.openai.com/docs/models/gpt-3-5
"""
try:
value_for_tiktoken = model.value_for_tiktoken
encoding = tiktoken.encoding_for_model(value_for_tiktoken)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
return _num_tokens_from_messages(messages, model.value_for_tiktoken)

if model.value_for_tiktoken.startswith("gpt-3.5-turbo"):

# flake8: noqa :E501
def _num_tokens_from_messages(messages: List[OpenAIMessage], model: str):
r"""Return the number of tokens used by a list of messages.
References:
- https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
# Every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_message = 4
# If there's a name, the role is omitted
tokens_per_name = -1
return count_tokens_openai_chat_models(
messages,
encoding,
tokens_per_message,
tokens_per_name,
)
elif model.value_for_tiktoken.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
return count_tokens_openai_chat_models(
messages,
encoding,
tokens_per_message,
tokens_per_name,
)
elif "gpt-3.5-turbo" in model:
return _num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
return _num_tokens_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"`num_tokens_from_messages`` is not presently implemented "
Expand All @@ -111,6 +135,9 @@ def num_tokens_from_messages(
f"See https://platform.openai.com/docs/models/gpt-4"
f"or https://platform.openai.com/docs/models/gpt-3-5"
f"for information about openai chat models.")
encoding = get_model_encoding(model)
return count_tokens_openai_chat_models(messages, encoding,
tokens_per_message, tokens_per_name)


def openai_api_key_required(func: F) -> F:
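
The rewritten num_tokens_from_messages in the hunk above now resolves a model string to a concrete snapshot (any "gpt-3.5-turbo" variant falls back to the 0613 accounting, likewise for "gpt-4") before applying the per-message constants. A small worked example of the resulting arithmetic, runnable with tiktoken installed; the sample messages are invented for illustration.

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]

# Mirrors count_tokens_openai_chat_models with the gpt-3.5-turbo-0613 constants:
# 3 tokens of overhead per message, +1 if a "name" field is present,
# plus 3 tokens priming the assistant reply.
tokens_per_message, tokens_per_name = 3, 1
num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
for message in messages:
    num_tokens += tokens_per_message
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":
            num_tokens += tokens_per_name

print(num_tokens)  # total tokens billed as prompt for this conversation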
@@ -150,8 +177,8 @@ def print_text_animated(text, delay: float = 0.02, end: str = ""):
text (str): The text to print.
delay (float, optional): The delay between each character printed.
(default: :obj:`0.02`)
end (str, optional): The end character to print after the text.
(default: :obj:`""`)
end (str, optional): The end character to print after each
character of text. (default: :obj:`""`)
"""
for char in text:
print(char, end=end, flush=True)