Merge branch 'master' into omegaPRM_openR
zjrwtx committed Dec 13, 2024
2 parents 5064dd0 + 33c2787 commit f60e69b
Showing 46 changed files with 3,060 additions and 184 deletions.
3 changes: 3 additions & 0 deletions camel/configs/__init__.py
@@ -30,6 +30,7 @@
     SambaCloudAPIConfig,
     SambaVerseAPIConfig,
 )
+from .sglang_config import SGLANG_API_PARAMS, SGLangConfig
 from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
 from .vllm_config import VLLM_API_PARAMS, VLLMConfig
 from .yi_config import YI_API_PARAMS, YiConfig
@@ -55,6 +56,8 @@
     'Gemini_API_PARAMS',
     'VLLMConfig',
     'VLLM_API_PARAMS',
+    'SGLangConfig',
+    'SGLANG_API_PARAMS',
     'MistralConfig',
     'MISTRAL_API_PARAMS',
     'RekaConfig',
6 changes: 4 additions & 2 deletions camel/configs/ollama_config.py
@@ -13,7 +13,9 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from __future__ import annotations

-from typing import Sequence, Union
+from typing import Sequence, Type, Union
+
+from pydantic import BaseModel

 from camel.configs.base_config import BaseConfig
 from camel.types import NOT_GIVEN, NotGiven
@@ -75,7 +77,7 @@ class OllamaConfig(BaseConfig):
     stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
     max_tokens: Union[int, NotGiven] = NOT_GIVEN
     presence_penalty: float = 0.0
-    response_format: Union[dict, NotGiven] = NOT_GIVEN
+    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
     frequency_penalty: float = 0.0


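The widened response_format field means an OllamaConfig can now carry a
Pydantic model class as a structured-output schema, not just a plain dict.
A minimal sketch of what this enables (the Pet model is hypothetical,
invented here for illustration):

from pydantic import BaseModel

from camel.configs import OllamaConfig


class Pet(BaseModel):
    # Hypothetical schema for structured output; not part of this commit.
    name: str
    species: str


# Before this change, response_format only accepted a dict; now a
# BaseModel subclass can be passed directly as the target schema.
config = OllamaConfig(temperature=0.2, response_format=Pet)
assert config.response_format is Pet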
71 changes: 71 additions & 0 deletions camel/configs/sglang_config.py
@@ -0,0 +1,71 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Sequence, Union

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class SGLangConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using
    SGLang's OpenAI-compatible API.

    Reference: https://sgl-project.github.io/references/sampling_params.html

    Args:
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`1.0`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        frequency_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on their
            existing frequency in the text so far, decreasing the model's
            likelihood to repeat the same line verbatim.
            (default: :obj:`0.0`)
        presence_penalty (float, optional): Number between :obj:`-2.0` and
            :obj:`2.0`. Positive values penalize new tokens based on whether
            they appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        stream (bool, optional): Whether to stream the generated output in
            chunks. If set to :obj:`True`, the response is streamed as it is
            generated. (default: :obj:`False`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
    """

    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stream: bool = False
    max_tokens: Union[int, NotGiven] = NOT_GIVEN


SGLANG_API_PARAMS = {param for param in SGLangConfig.model_fields.keys()}
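A quick sanity-check sketch for the new config. The validation loop is
illustrative of how CAMEL model backends typically consume these
*_API_PARAMS sets before forwarding kwargs to a server; it is not code
from this commit:

from camel.configs import SGLANG_API_PARAMS, SGLangConfig

config = SGLangConfig(temperature=0.7, top_p=0.9, max_tokens=256)

# Backend-style validation of user-supplied kwargs: anything outside
# SGLANG_API_PARAMS is rejected before the request is built.
user_kwargs = {"temperature": 0.7, "top_k": 40}  # top_k is not a field here
for param in user_kwargs:
    if param not in SGLANG_API_PARAMS:
        raise ValueError(f"Unexpected SGLang parameter: {param}")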
19 changes: 19 additions & 0 deletions camel/data_collector/__init__.py
@@ -0,0 +1,19 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from .alpaca_collector import AlpacaDataCollector
from .base import BaseDataCollector
from .sharegpt_collector import ShareGPTDataCollector

__all__ = ["BaseDataCollector", "AlpacaDataCollector", "ShareGPTDataCollector"]
127 changes: 127 additions & 0 deletions camel/data_collector/alpaca_collector.py
@@ -0,0 +1,127 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Any, Dict, List, Optional, Union

from typing_extensions import Self

from camel.agents import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages import AlpacaItem, BaseMessage
from camel.schemas import OpenAISchemaConverter

# ruff: noqa: E501
DEFAULT_CONVERTER_PROMPTS = """
Extract key entities and attributes from the conversations
and convert them into a structured JSON format.
For example:
Instruction: You are a helpful assistant.
User: When is the release date of the video game Portal?
Assistant: The release date of the video game Portal is October 9.
Your output should be:
{
    "instruction": "You are a helpful assistant. When is the release date of the video game Portal?",
    "input": "",
    "output": "The release date of the video game Portal is October 9."
}
"""


class AlpacaDataCollector(BaseDataCollector):
    def __init__(self) -> None:
        super().__init__()
        self.system_message: Optional[BaseMessage] = None
        self.agent_name: Optional[str] = None

    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]): The agent to inject.
        """
        if not self.agent_name:
            _agent = agent if isinstance(agent, ChatAgent) else agent[0]
            self.agent_name = _agent.role_name
            self.system_message = _agent._system_message
        super().record(agent)
        return self

    def convert(self) -> Dict[str, Any]:
        r"""Convert the collected data into a dictionary."""
        if self.agent_name is None:
            raise ValueError("No agent injected.")

        history = self.get_agent_history(self.agent_name)
        if not history:
            raise ValueError("No data collected.")

        # Validate and process history
        if len(history) == 3 and history[0].role == "system":
            history = history[1:]  # Ignore the system message.
        elif len(history) != 2:
            raise ValueError(
                f"AlpacaDataCollector only supports one message pair, but "
                f"got {len(history)}"
            )

        input_message, output_message = history
        instruction = (
            self.system_message.content if self.system_message else ""
        ) + str(input_message.message)

        data = {
            "instruction": instruction,
            "input": "",
            "output": output_message.message,
        }
        self.data.append(data)
        return data

    def llm_convert(
        self,
        converter: Optional[OpenAISchemaConverter] = None,
        prompt: Optional[str] = None,
    ) -> Dict[str, str]:
        r"""Convert collected data using an LLM schema converter.

        Args:
            converter (Optional[OpenAISchemaConverter], optional):
                The converter to use. (default: :obj:`OpenAISchemaConverter`)
            prompt (Optional[str], optional): Prompt to guide the conversion.
                (default: :obj:`DEFAULT_CONVERTER_PROMPTS`)

        Returns:
            Dict[str, str]: The converted data.

        Raises:
            ValueError: If no agent has been injected or no data has been
                collected.
        """
        prompt = prompt or DEFAULT_CONVERTER_PROMPTS
        converter = converter or OpenAISchemaConverter()

        system = self.system_message.content if self.system_message else ""
        context = [f"Instruction: {system}\n"]

        for message in self.get_agent_history(str(self.agent_name)):
            if message.role == "user":
                context.append(f"User: {message.message}\n")
            else:
                context.append(f"{message.name}: {message.message}\n")

        return converter.convert(
            "\n".join(context), AlpacaItem, prompt=prompt
        ).model_dump()
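End to end, the collector wraps an agent, records a single exchange, and
emits one Alpaca-style record. A usage sketch, assuming ChatAgent accepts a
plain string system message and BaseDataCollector exposes start() to begin
capturing, as in CAMEL's examples around this commit:

from camel.agents import ChatAgent
from camel.data_collector import AlpacaDataCollector

agent = ChatAgent("You are a helpful assistant.")
collector = AlpacaDataCollector().record(agent).start()

agent.step("When is the release date of the video game Portal?")

# convert() expects exactly one user/assistant pair (plus an optional
# leading system message) and returns
# {"instruction": ..., "input": "", "output": ...}.
data = collector.convert()

# Alternatively, delegate the conversion to an LLM schema converter
# (the default OpenAISchemaConverter requires an OpenAI API key).
structured = collector.llm_convert()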
