
Commit
add examples
zjrwtx committed Dec 13, 2024
1 parent 33c2787 commit 12ec2ec
Showing 6 changed files with 501 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/omegaPRM_openR/.env.example
@@ -0,0 +1,5 @@
# When customizing the model, you need to set the model to call, e.g. the deepseek-chat model.

OPENAI_COMPATIBILITY_MODEL_TYPE=deepseek-chat
OPENAI_COMPATIBILITY_API_BASE_URL=https://api.deepseek.com
OPENAI_COMPATIBILITY_API_KEY=
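
These variables are read in model_utils.py below via python-dotenv; copy this file to .env and fill in your API key. A minimal sketch of the lookup (the variable names match this file):

import os

from dotenv import load_dotenv

load_dotenv()  # loads .env from the working directory
model_type = os.environ.get("OPENAI_COMPATIBILITY_MODEL_TYPE")  # e.g. deepseek-chat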
25 changes: 25 additions & 0 deletions examples/omegaPRM_openR/config.yaml
@@ -0,0 +1,25 @@
input:
  json_file_path: 'example_problems.json'

output:
  file_prefix: 'example'
  log_file_path: 'example_processing.log'

processing:
  initial_rollouts: 20
  num_rollouts: 20
  max_iterations: 100

model:
  model_type: "camel"
  model_name: "deepseek-chat"
  model_args:
    max_tokens: 200
    temperature_range: [0.7, 1.0]


# A batch of initial rollouts is executed for each question first.
# If a question is too easy (all rollouts right) or too hard (all wrong),
# OmegaPRM will not process it.
# So you may need a problem set of the right difficulty and a model that is
# not too strong, such as one of Qwen's small open-source models; then the
# node data can be obtained. (See the filter sketched below.)
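
For reference, this difficulty filter appears in gen_data.py below: a question is annotated only when some, but not all, of its initial rollouts are correct.

if 0 < sum(correctness_flags) < initial_rollouts:
    # neither all right nor all wrong: worth annotating
    process_annotations(problem, nodes, lm_model, filename, max_iterations)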
6 changes: 6 additions & 0 deletions examples/omegaPRM_openR/example_problems.json
@@ -0,0 +1,6 @@
[
  {
    "problem": "How many ways can we put 3 math books and 5 English books on a shelf if all the math books must stay together and all the English books must also stay together? (The math books are all different and so are the English books.)",
    "final_answer": "1440"
  }
]
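
Each entry must provide the keys "problem" and "final_answer" (gen_data.py below falls back to placeholder strings when either is missing). Additional problems can be appended in the same shape; this entry is a hypothetical template, not part of the shipped file:

  {
    "problem": "<your problem statement>",
    "final_answer": "<its final answer as a string>"
  }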
146 changes: 146 additions & 0 deletions examples/omegaPRM_openR/gen_data.py
@@ -0,0 +1,146 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging

import yaml
from model_utils import LM
from module import (
    Node,
    calculate_mc_score,
    perform_rollouts,
    process_annotations,
)


def load_config(config_path):
    """
    Load configuration from a YAML file.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        dict: A dictionary containing the configuration.
    """
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


def load_json_file(file_path):
    """
    Load data from a JSON file.

    Args:
        file_path (str): Path to the JSON file.

    Returns:
        list: A list of dictionaries containing the problem and final answer.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data


def setup_logging(log_file):
    """
    Set up logging configuration to output to file and console.

    Args:
        log_file (str): Path to the log file.
    """
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    console_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(console_handler)

    logging.getLogger("openai").setLevel(logging.ERROR)
    logging.getLogger("httpx").setLevel(logging.WARNING)


def main():
    # Load configuration
    config = load_config('config.yaml')

    # Get parameters from config
    json_file_path = config['input']['json_file_path']
    log_file_path = config['output']['log_file_path']
    file_prefix = config['output']['file_prefix']
    num_rollouts = config['processing']['num_rollouts']
    initial_rollouts = config['processing']['initial_rollouts']
    max_iterations = config['processing']['max_iterations']

    lm_model = LM(
        model_type=config['model']['model_type'],
        model_name=config['model']['model_name'],
        num_rollouts=num_rollouts,
        **config['model']['model_args'],
    )

    # Set up logging
    setup_logging(log_file_path)

    # Start the process and log it
    logging.info("Started processing the JSON file.")

    # Load the JSON data
    data = load_json_file(json_file_path)

    # Process each problem and its final answer
    for i, item in enumerate(data):
        problem = item.get('problem', 'No problem found')
        final_answer = item.get('final_answer', 'No answer found')

        # Log each problem and answer
        logging.info(f"Processed Problem {i + 1}: {problem}")
        logging.info(f"Final Answer: {final_answer}")

        # Initialize the root node and perform rollouts
        nodes = []
        root_node = Node(problem, "", final_answer)
        rollouts, correctness_flags = perform_rollouts(
            root_node, lm_model, initial_rollouts
        )
        mc_score = calculate_mc_score(root_node)
        root_node.mc_score = mc_score

        nodes.append(root_node)

        # Check if further processing is needed
        if 0 < sum(correctness_flags) < initial_rollouts:
            print("Processing annotations ...\n")
            filename = f"{file_prefix}_{i + 1}_nodes_data.json"
            process_annotations(
                problem, nodes, lm_model, filename, max_iterations
            )

    # Log completion
    logging.info("Finished processing the JSON file.")


if __name__ == "__main__":
    main()
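
With the shipped config.yaml, a problem that passes the difficulty filter writes its annotated node data to example_1_nodes_data.json (the 'example' file_prefix plus the 1-based problem index), and progress is logged both to example_processing.log and to the console.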
96 changes: 96 additions & 0 deletions examples/omegaPRM_openR/model_utils.py
@@ -0,0 +1,96 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
import random

from dotenv import load_dotenv
from tqdm import tqdm

from camel.agents import ChatAgent
from camel.models import ModelFactory
from camel.types import ModelPlatformType

load_dotenv()


class LM:
    def __init__(self, model_type, model_name, num_rollouts=5, **kwargs):
        self.model_type = model_type
        self.model_name = model_name
        self.num_rollouts = num_rollouts
        self.max_tokens = kwargs.get('max_tokens', 4096)
        self.temperature_range = kwargs.get('temperature_range', [0.7, 1.0])

        if self.model_type != "camel":
            raise ValueError("Only the camel model type is supported")

        # Initialize camel model
        self.model = ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
            model_type=os.environ.get("OPENAI_COMPATIBILITY_MODEL_TYPE"),
            api_key=os.environ.get("OPENAI_COMPATIBILITY_API_KEY"),
            url=os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL"),
            model_config_dict={
                "temperature": random.uniform(*self.temperature_range),
                "max_tokens": self.max_tokens,
            },
        )

        # Initialize chat agent
        self.agent = ChatAgent(
            system_message='''You are a mathematical reasoning expert who
always solves problems step by step.
For each step:
1. Write down what you're calculating
2. Show the calculation
3. Explain the result
Always show your work, even for simple calculations.
End your solution with the final numerical answer.''',
            model=self.model,
            message_window_size=10,
        )

    def generate(self, question, partial_answer, num_rollouts=None):
        results = []
        if num_rollouts is None:
            num_rollouts = self.num_rollouts

        for _ in tqdm(range(num_rollouts)):
            # Update temperature for each rollout
            self.model.model_config_dict["temperature"] = random.uniform(
                *self.temperature_range
            )

            # Construct the prompt
            if partial_answer:
                prompt = f"""Problem: {question}
Current solution steps:
{partial_answer}
Continue the solution, showing all steps and calculations.
Make sure to explain each step:"""
            else:
                prompt = f"""Problem: {question}
Please solve this step by step, showing all calculations and
explaining each step.
Remember to:
1. Break down the problem
2. Show all calculations
3. Explain each step
4. End with the final numerical answer."""

            # Get response from agent
            response = self.agent.step(prompt)
            results.append(response.msgs[0].content)

        return results
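
A minimal usage sketch of LM, mirroring how gen_data.py constructs it with the values from config.yaml above; it assumes the .env variables are set, and the toy question is a placeholder:

lm = LM(
    model_type="camel",
    model_name="deepseek-chat",
    num_rollouts=20,
    max_tokens=200,
    temperature_range=[0.7, 1.0],
)
completions = lm.generate("What is 2 + 2?", partial_answer="")  # list of 20 sampled solutions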
223 changes: 223 additions & 0 deletions examples/omegaPRM_openR/module.py
(The diff for this file did not render in this view; the line count is inferred from the 501-addition total above.)
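
Since the file did not render, here is a minimal sketch of the interface module.py must expose, inferred purely from how gen_data.py calls it; every body below is an assumption, not the actual 223-line implementation.

class Node:
    """A (problem, partial solution) state in the OmegaPRM search tree."""

    def __init__(self, problem, partial_answer, final_answer):
        self.problem = problem
        self.partial_answer = partial_answer
        self.final_answer = final_answer
        self.rollouts = []       # completions sampled from this state
        self.mc_score = None     # Monte Carlo estimate set by the caller


def perform_rollouts(node, lm_model, num_rollouts):
    """Sample completions for the node and flag which reach the final answer."""
    node.rollouts = lm_model.generate(
        node.problem, node.partial_answer, num_rollouts
    )
    # Assumed correctness check: the known answer appears in the completion.
    correctness_flags = [
        int(node.final_answer in rollout) for rollout in node.rollouts
    ]
    return node.rollouts, correctness_flags


def calculate_mc_score(node):
    """Fraction of the node's rollouts that contain the final answer."""
    if not node.rollouts:
        return 0.0
    correct = sum(int(node.final_answer in r) for r in node.rollouts)
    return correct / len(node.rollouts)


def process_annotations(problem, nodes, lm_model, filename, max_iterations):
    """OmegaPRM annotation loop: expand nodes, score them, dump JSON.

    The real implementation performs the OmegaPRM search over partial
    solutions; this stub only records the signature and output file.
    """
    ...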
