Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created models subpackage #167

Merged
merged 9 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def step(
info: Dict[str, Any]

if num_tokens < self.model_token_limit:
response = self.model_backend.run(messages=openai_messages)
response = self.model_backend.run(openai_messages)
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
Expand Down
54 changes: 48 additions & 6 deletions camel/model_backend.py
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, List, Optional

from camel.typing import ModelType

Expand All @@ -22,9 +22,13 @@ class ModelBackend(ABC):
May be OpenAI API, a local LLM, a stub for unit tests, etc."""

@abstractmethod
def run(self, *args, **kwargs) -> Dict[str, Any]:
def run(self, messages: List[Dict]) -> Dict[str, Any]:
r"""Runs the query to the backend model.

Args:
messages (List[Dict]): message list with the chat history
in OpenAI API format.

Raises:
RuntimeError: if the return value from OpenAI API
is not a dict that is expected.
Expand All @@ -39,14 +43,30 @@ class OpenAIModel(ModelBackend):
r"""OpenAI API in a unified ModelBackend interface."""

def __init__(self, model_type: ModelType, model_config_dict: Dict) -> None:
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
r"""Constructor for OpenAI backend.

Args:
model_type (ModelType): Model for which a backend is created,
one of GPT_* series.
model_config_dict (Dict): A dictionary that will be fed into
openai.ChatCompletion.create().
"""
super().__init__()
self.model_type = model_type
self.model_config_dict = model_config_dict

def run(self, *args, **kwargs) -> Dict[str, Any]:
def run(self, messages: List[Dict]) -> Dict[str, Any]:
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
r"""Run inference of OpenAI chat completion.

Args:
messages (List[Dict]): message list with the chat history
in OpenAI API format.

Returns:
Dict[str, Any]: Response in the OpenAI API format.
"""
import openai
response = openai.ChatCompletion.create(*args, **kwargs,
response = openai.ChatCompletion.create(messages=messages,
model=self.model_type.value,
**self.model_config_dict)
if not isinstance(response, Dict):
Expand All @@ -58,9 +78,16 @@ class StubModel(ModelBackend):
r"""A dummy model used for unit tests."""

def __init__(self, *args, **kwargs) -> None:
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
r"""All arguments are unused for the dummy model."""
super().__init__()

def run(self, *args, **kwargs) -> Dict[str, Any]:
def run(self, messages: List[Dict]) -> Dict[str, Any]:
r"""Run fake inference by returning a fixed string.
All arguments are unused for the dummy model.

Returns:
Dict[str, Any]: Response in the OpenAI API format.
"""
ARBITRARY_STRING = "Lorem Ipsum"

return dict(
Expand All @@ -81,9 +108,24 @@ class ModelFactory:
"""

@staticmethod
def create(model_type: ModelType, model_config_dict: Dict) -> ModelBackend:
def create(model_type: Optional[ModelType],
model_config_dict: Dict) -> ModelBackend:
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
r"""Creates an instance of `ModelBackend` of the specified type.

Args:
model_type (ModelType): Model for which a backend is created.
model_config_dict (Dict): A dictionary that will be fed into
the backend constructor.

Raises:
ValueError: If there is no backend for the model.

Returns:
ModelBackend: The initialized backend.
"""
default_model_type = ModelType.GPT_3_5_TURBO
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved

model_class: Any
if model_type in {
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k,
None
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
4 changes: 0 additions & 4 deletions examples/test/test_ai_society_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,3 @@
@pytest.mark.slow
def test_ai_society_role_playing_example():
examples.ai_society.role_playing.main(ModelType.STUB)


if __name__ == "__main__":
pytest.main()
2 changes: 1 addition & 1 deletion test/test_model_backend.py
Obs01ete marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_openai_model(model):
'Do not add anything else.')
},
]
response = model_inst.run(messages=messages)
response = model_inst.run(messages)
assert isinstance(response, dict)
assert 'id' in response
assert isinstance(response['id'], str)
Expand Down