Created models subpackage #167

Merged 9 commits on Jun 17, 2023
Changes from 6 commits
6 changes: 3 additions & 3 deletions camel/agents/chat_agent.py
@@ -21,7 +21,7 @@
from camel.agents import BaseAgent
from camel.configs import ChatGPTConfig
from camel.messages import ChatMessage, MessageType, SystemMessage
-from camel.model_backend import ModelBackend, ModelFactory
+from camel.models import BaseModel, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import (
get_model_token_limit,
@@ -87,7 +87,7 @@ def __init__(
self.model_token_limit: int = get_model_token_limit(self.model)
self.message_window_size: Optional[int] = message_window_size

-self.model_backend: ModelBackend = ModelFactory.create(
+self.model_backend: BaseModel = ModelFactory.create(
self.model, self.model_config.__dict__)

self.terminated: bool = False
@@ -180,7 +180,7 @@ def step(
info: Dict[str, Any]

if num_tokens < self.model_token_limit:
-response = self.model_backend.run(messages=openai_messages)
+response = self.model_backend.run(openai_messages)
if not isinstance(response, dict):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
101 changes: 0 additions & 101 deletions camel/model_backend.py

This file was deleted.

24 changes: 24 additions & 0 deletions camel/models/__init__.py
@@ -0,0 +1,24 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base_model import BaseModel
from .openai_model import OpenAIModel
from .stub_model import StubModel
from .model_factory import ModelFactory

__all__ = [
'BaseModel',
'OpenAIModel',
'StubModel',
'ModelFactory',
]
37 changes: 37 additions & 0 deletions camel/models/base_model.py
@@ -0,0 +1,37 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseModel(ABC):
r"""Base class for different model backends.
May be OpenAI API, a local LLM, a stub for unit tests, etc."""

@abstractmethod
def run(self, messages: List[Dict]) -> Dict[str, Any]:
r"""Runs the query to the backend model.

Args:
messages (List[Dict]): message list with the chat history
in OpenAI API format.

Raises:
RuntimeError: If the return value from the OpenAI API
is not a dict as expected.

Returns:
Dict[str, Any]: All backends must return a dict in OpenAI format.
"""
pass
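
A minimal sketch of a custom backend, assuming only the interface added in this PR; EchoModel is a hypothetical name, not part of this change, and it mirrors the response structure StubModel uses below:

from typing import Any, Dict, List

from camel.models import BaseModel


class EchoModel(BaseModel):
    r"""Hypothetical backend that echoes the last message back."""

    def run(self, messages: List[Dict]) -> Dict[str, Any]:
        last_content = messages[-1]["content"] if messages else ""
        # Return a dict shaped like an OpenAI ChatCompletion response,
        # as required by BaseModel.run().
        return dict(
            id="echo_model_id",
            usage=dict(),
            choices=[
                dict(finish_reason="stop",
                     message=dict(content=last_content, role="assistant"))
            ],
        )
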
55 changes: 55 additions & 0 deletions camel/models/model_factory.py
@@ -0,0 +1,55 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict

from camel.models.base_model import BaseModel
from camel.models.openai_model import OpenAIModel
from camel.models.stub_model import StubModel
from camel.typing import ModelType


class ModelFactory:
r"""Factory of backend models.

Raises:
ValueError: in case the provided model type is unknown.
"""

@staticmethod
def create(model_type: ModelType, model_config_dict: Dict) -> BaseModel:
r"""Creates an instance of `BaseModel` of the specified type.

Args:
model_type (ModelType): Model for which a backend is created.
model_config_dict (Dict): a dictionary that will be fed into
the backend constructor.

Raises:
ValueError: If there is no backend for the model.

Returns:
BaseModel: The initialized backend.
"""
model_class: Any
if model_type in {
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k
}:
model_class = OpenAIModel
elif model_type == ModelType.STUB:
model_class = StubModel
else:
raise ValueError("Unknown model")

inst = model_class(model_type, model_config_dict)
return inst
55 changes: 55 additions & 0 deletions camel/models/openai_model.py
@@ -0,0 +1,55 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

from camel.messages import OpenAIMessage
from camel.models.base_model import BaseModel
from camel.typing import ModelType


class OpenAIModel(BaseModel):
r"""OpenAI API in a unified BaseModel interface."""

def __init__(self, model_type: ModelType,
model_config_dict: Dict[str, Any]) -> None:
r"""Constructor for OpenAI backend.

Args:
model_type (ModelType): Model for which a backend is created,
one of GPT_* series.
model_config_dict (Dict[str, Any]): a dictionary that will
be fed into openai.ChatCompletion.create().
"""
super().__init__()
self.model_type = model_type
self.model_config_dict = model_config_dict

def run(self, messages: List[Dict]) -> Dict[str, Any]:
r"""Run inference of OpenAI chat completion.

Args:
messages (List[Dict]): message list with the chat history
in OpenAI API format.

Returns:
Dict[str, Any]: Response in the OpenAI API format.
"""
import openai
messages_openai: List[OpenAIMessage] = messages
response = openai.ChatCompletion.create(messages=messages_openai,
model=self.model_type.value,
**self.model_config_dict)
if not isinstance(response, dict):
raise RuntimeError("Unexpected return from OpenAI API")
return response
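
Direct use without the factory looks like the sketch below; it assumes a pre-1.0 openai package and OPENAI_API_KEY in the environment, since the config dict is splatted straight into openai.ChatCompletion.create():

from camel.configs import ChatGPTConfig
from camel.models import OpenAIModel
from camel.typing import ModelType

backend = OpenAIModel(ModelType.GPT_3_5_TURBO, ChatGPTConfig().__dict__)
reply = backend.run([
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Reply with exactly one word: ping"},
])
# The response keeps the raw OpenAI format, e.g. reply["usage"] and
# reply["choices"][0]["message"]["content"].
print(reply["choices"][0]["message"]["content"])
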
42 changes: 42 additions & 0 deletions camel/models/stub_model.py
@@ -0,0 +1,42 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict, List

from camel.models.base_model import BaseModel


class StubModel(BaseModel):
r"""A dummy model used for unit tests."""

def __init__(self, *args, **kwargs) -> None:
r"""All arguments are unused for the dummy model."""
super().__init__()

def run(self, messages: List[Dict]) -> Dict[str, Any]:
r"""Run fake inference by returning a fixed string.
All arguments are unused for the dummy model.

Returns:
Dict[str, Any]: Response in the OpenAI API format.
"""
ARBITRARY_STRING = "Lorem Ipsum"

return dict(
id="stub_model_id",
usage=dict(),
choices=[
dict(finish_reason="stop",
message=dict(content=ARBITRARY_STRING, role="assistant"))
],
)
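
The stub lets unit tests unpack a response exactly like a real OpenAI reply; a short sketch, mirroring test/test_model_backend.py below:

from camel.models import StubModel

backend = StubModel()  # constructor arguments are ignored
response = backend.run([{"role": "user", "content": "anything"}])
assert response["choices"][0]["finish_reason"] == "stop"
assert response["choices"][0]["message"]["content"] == "Lorem Ipsum"
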
4 changes: 0 additions & 4 deletions examples/test/test_ai_society_example.py
@@ -20,7 +20,3 @@
@pytest.mark.slow
def test_ai_society_role_playing_example():
examples.ai_society.role_playing.main(ModelType.STUB)


-if __name__ == "__main__":
-    pytest.main()
5 changes: 2 additions & 3 deletions test/test_model_backend.py
@@ -14,12 +14,11 @@
import pytest

from camel.configs import ChatGPTConfig
-from camel.model_backend import ModelFactory
+from camel.models import ModelFactory
from camel.typing import ModelType

parametrize = pytest.mark.parametrize('model', [
ModelType.STUB,
-pytest.param(None, marks=pytest.mark.model_backend),
pytest.param(ModelType.GPT_3_5_TURBO, marks=pytest.mark.model_backend),
pytest.param(ModelType.GPT_4, marks=pytest.mark.model_backend),
])
@@ -45,7 +44,7 @@ def test_openai_model(model)
'Do not add anything else.')
},
]
-response = model_inst.run(messages=messages)
+response = model_inst.run(messages)
assert isinstance(response, dict)
assert 'id' in response
assert isinstance(response['id'], str)