Skip to content

Commit

Permalink
Created models subpackage (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
Obs01ete committed Jun 17, 2023
1 parent 426c028 commit 6a27e23
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 112 deletions.
6 changes: 3 additions & 3 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from camel.agents import BaseAgent
from camel.configs import ChatGPTConfig
from camel.messages import ChatMessage, MessageType, SystemMessage
from camel.model_backend import ModelBackend, ModelFactory
from camel.models import BaseModelBackend, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import (
get_model_token_limit,
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
self.model_token_limit: int = get_model_token_limit(self.model)
self.message_window_size: Optional[int] = message_window_size

self.model_backend: ModelBackend = ModelFactory.create(
self.model_backend: BaseModelBackend = ModelFactory.create(
self.model, self.model_config.__dict__)

self.terminated: bool = False
Expand Down Expand Up @@ -180,7 +180,7 @@ def step(
info: Dict[str, Any]

if num_tokens < self.model_token_limit:
response = self.model_backend.run(messages=openai_messages)
response = self.model_backend.run(openai_messages)
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
Expand Down
101 changes: 0 additions & 101 deletions camel/model_backend.py

This file was deleted.

24 changes: 24 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base_model import BaseModelBackend
from .openai_model import OpenAIModel
from .stub_model import StubModel
from .model_factory import ModelFactory

__all__ = [
'BaseModelBackend',
'OpenAIModel',
'StubModel',
'ModelFactory',
]
37 changes: 37 additions & 0 deletions camel/models/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseModelBackend(ABC):
    r"""Abstract interface unifying different chat-model backends
    (OpenAI API, a local LLM, a stub for unit tests, etc.)."""

    @abstractmethod
    def run(self, messages: List[Dict]) -> Dict[str, Any]:
        r"""Sends the chat history to the backend model and returns its reply.

        Args:
            messages (List[Dict]): Message list with the chat history
                in OpenAI API format.

        Raises:
            RuntimeError: If the value returned by the backend is not
                the expected dict.

        Returns:
            Dict[str, Any]: All backends must return a dict in OpenAI format.
        """
54 changes: 54 additions & 0 deletions camel/models/model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict

from camel.models import BaseModelBackend, OpenAIModel, StubModel
from camel.typing import ModelType


class ModelFactory:
    r"""Factory of backend models.

    Raises:
        ValueError: in case the provided model type is unknown.
    """

    @staticmethod
    def create(model_type: ModelType,
               model_config_dict: Dict) -> BaseModelBackend:
        r"""Creates an instance of `BaseModelBackend` of the specified type.

        Args:
            model_type (ModelType): Model for which a backend is created.
            model_config_dict (Dict): A dictionary that will be fed into
                the backend constructor.

        Raises:
            ValueError: If there is no backend for the model.

        Returns:
            BaseModelBackend: The initialized backend.
        """
        model_class: Any
        if model_type in {
                ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k
        }:
            model_class = OpenAIModel
        elif model_type == ModelType.STUB:
            model_class = StubModel
        else:
            # Name the offending value so the caller can tell which
            # model type is unsupported.
            raise ValueError(f"Unknown model type `{model_type}`")

        return model_class(model_type, model_config_dict)
55 changes: 55 additions & 0 deletions camel/models/openai_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.typing import ModelType


class OpenAIModel(BaseModelBackend):
    r"""OpenAI API in a unified BaseModelBackend interface."""

    def __init__(self, model_type: ModelType,
                 model_config_dict: Dict[str, Any]) -> None:
        r"""Constructor for OpenAI backend.

        Args:
            model_type (ModelType): Model for which a backend is created,
                one of GPT_* series.
            model_config_dict (Dict[str, Any]): A dictionary that will
                be fed into openai.ChatCompletion.create().
        """
        super().__init__()
        self.model_type = model_type
        self.model_config_dict = model_config_dict

    def run(self, messages: List[Dict]) -> Dict[str, Any]:
        r"""Run inference of OpenAI chat completion.

        Args:
            messages (List[Dict]): Message list with the chat history
                in OpenAI API format.

        Raises:
            RuntimeError: If the return value from the OpenAI API is not
                a dict as expected.

        Returns:
            Dict[str, Any]: Response in the OpenAI API format.
        """
        # Imported lazily so that merely constructing the backend does not
        # require the `openai` package to be importable.
        import openai
        messages_openai: List[OpenAIMessage] = messages
        response = openai.ChatCompletion.create(messages=messages_openai,
                                                model=self.model_type.value,
                                                **self.model_config_dict)
        # isinstance must check against the builtin `dict`, not
        # `typing.Dict` (isinstance on typing generics is deprecated).
        if not isinstance(response, dict):
            raise RuntimeError("Unexpected return from OpenAI API")
        return response
42 changes: 42 additions & 0 deletions camel/models/stub_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

from camel.models import BaseModelBackend


class StubModel(BaseModelBackend):
    r"""A dummy model used for unit tests."""

    def __init__(self, *args, **kwargs) -> None:
        r"""All arguments are unused for the dummy model."""
        super().__init__()

    def run(self, messages: List[Dict]) -> Dict[str, Any]:
        r"""Run fake inference by returning a fixed string.
        All arguments are unused for the dummy model.

        Returns:
            Dict[str, Any]: Response in the OpenAI API format.
        """
        ARBITRARY_STRING = "Lorem Ipsum"
        stub_message = {"content": ARBITRARY_STRING, "role": "assistant"}
        return {
            "id": "stub_model_id",
            "usage": {},
            "choices": [
                {"finish_reason": "stop", "message": stub_message},
            ],
        }
4 changes: 0 additions & 4 deletions examples/test/test_ai_society_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,3 @@
@pytest.mark.slow
def test_ai_society_role_playing_example():
    # Smoke-test the AI society role-playing example end to end, driven by
    # the stub model backend (presumably to avoid real API calls — the stub
    # returns a canned response; confirm against StubModel).
    examples.ai_society.role_playing.main(ModelType.STUB)


if __name__ == "__main__":
pytest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@
import pytest

from camel.configs import ChatGPTConfig
from camel.model_backend import ModelFactory
from camel.models import ModelFactory
from camel.typing import ModelType

parametrize = pytest.mark.parametrize('model', [
ModelType.STUB,
pytest.param(None, marks=pytest.mark.model_backend),
pytest.param(ModelType.GPT_3_5_TURBO, marks=pytest.mark.model_backend),
pytest.param(ModelType.GPT_4, marks=pytest.mark.model_backend),
])


@parametrize
def test_openai_model(model):
def test_model_factory(model):
model_config_dict = ChatGPTConfig().__dict__
model_inst = ModelFactory.create(model, model_config_dict)
messages = [
Expand All @@ -45,7 +44,7 @@ def test_openai_model(model):
'Do not add anything else.')
},
]
response = model_inst.run(messages=messages)
response = model_inst.run(messages)
assert isinstance(response, dict)
assert 'id' in response
assert isinstance(response['id'], str)
Expand Down

0 comments on commit 6a27e23

Please sign in to comment.