diff --git a/examples/omegaPRM_openR/.env.example b/examples/omegaPRM_openR/.env.example
new file mode 100644
index 0000000000..49f6eb1d68
--- /dev/null
+++ b/examples/omegaPRM_openR/.env.example
@@ -0,0 +1,5 @@
+# When customizing the model, set the model type to the model you want to call, e.g. deepseek-chat.
+
+OPENAI_COMPATIBILITY_ModelType=deepseek-chat
+OPENAI_COMPATIBILITY_API_BASE_URL=https://api.deepseek.com
+OPENAI_COMPATIBILITY_API_KEY=
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/config.yaml b/examples/omegaPRM_openR/config.yaml
new file mode 100644
index 0000000000..7d59760b39
--- /dev/null
+++ b/examples/omegaPRM_openR/config.yaml
@@ -0,0 +1,25 @@
+input:
+  json_file_path: 'example_problems.json'
+
+output:
+  file_prefix: 'example'
+  log_file_path: 'example_processing.log'
+
+processing:
+  initial_rollouts: 20
+  num_rollouts: 20
+  max_iterations: 100
+
+model:
+  model_type: "camel"
+  model_name: "deepseek-chat"
+  model_args:
+    max_tokens: 200
+    temperature_range: [0.7, 1.0]
+
+
+# A batch of initial rollouts (initial_rollouts) is executed for each question first.
+# If a question is too easy (all rollouts correct) or too hard (all rollouts wrong),
+# OmegaPRM will not process it further.
+# So you may need a problem set of the right difficulty and a model that is not
+# too strong, such as one of Qwen's open-source small models, for node data to be produced.
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/example_problems.json b/examples/omegaPRM_openR/example_problems.json
new file mode 100644
index 0000000000..bfec425166
--- /dev/null
+++ b/examples/omegaPRM_openR/example_problems.json
@@ -0,0 +1,6 @@
+[
+    {
+        "problem": "How many ways can we put 3 math books and 5 English books on a shelf if all the math books must stay together and all the English books must also stay together? (The math books are all different and so are the English books.)",
+        "final_answer": "1440"
+    }
+]
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/gen_data.py b/examples/omegaPRM_openR/gen_data.py
new file mode 100644
index 0000000000..d412256fd7
--- /dev/null
+++ b/examples/omegaPRM_openR/gen_data.py
@@ -0,0 +1,146 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import json
+import logging
+
+import yaml
+from model_utils import LM
+from module import (
+    Node,
+    calculate_mc_score,
+    perform_rollouts,
+    process_annotations,
+)
+
+
+def load_config(config_path):
+    """
+    Load configuration from a YAML file.
+
+    Args:
+        config_path (str): Path to the YAML configuration file.
+
+    Returns:
+        dict: A dictionary containing the configuration.
+    """
+    with open(config_path, 'r') as file:
+        return yaml.safe_load(file)
+
+
+def load_json_file(file_path):
+    """
+    Load data from a JSON file.
+
+    Args:
+        file_path (str): Path to the JSON file.
+
+    Returns:
+        list: A list of dictionaries containing the problem and final answer.
+    """
+    with open(file_path, 'r') as file:
+        data = json.load(file)
+    return data
+
+
+def setup_logging(log_file):
+    """
+    Set up logging configuration to output to file and console.
+
+    Args:
+        log_file (str): Path to the log file.
+    """
+    logging.basicConfig(
+        filename=log_file,
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+    )
+
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        '%(asctime)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+    )
+    console_handler.setFormatter(formatter)
+
+    root_logger = logging.getLogger()
+    root_logger.addHandler(console_handler)
+
+    logging.getLogger("openai").setLevel(logging.ERROR)
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+
+
+def main():
+    # Load configuration
+    config = load_config('config.yaml')
+
+    # Get parameters from config
+    json_file_path = config['input']['json_file_path']
+    log_file_path = config['output']['log_file_path']
+    file_prefix = config['output']['file_prefix']
+    num_rollouts = config['processing']['num_rollouts']
+    initial_rollouts = config['processing']['initial_rollouts']
+    max_iterations = config['processing']['max_iterations']
+
+    lm_model = LM(
+        model_type=config['model']['model_type'],
+        model_name=config['model']['model_name'],
+        num_rollouts=num_rollouts,
+        **config['model']['model_args'],
+    )
+
+    # Set up logging
+    setup_logging(log_file_path)
+
+    # Start the process and log it
+    logging.info("Started processing the JSON file.")
+
+    # Load the JSON data
+    data = load_json_file(json_file_path)
+
+    # Process each problem and its final answer
+    for i, item in enumerate(data):
+        problem = item.get('problem', 'No problem found')
+        final_answer = item.get('final_answer', 'No answer found')
+
+        # Log each problem and answer
+        logging.info(f"Processing Problem {i + 1}: {problem}")
+        logging.info(f"Final Answer: {final_answer}")
+
+        # Initialize the root node and perform rollouts
+        nodes = []
+        root_node = Node(problem, "", final_answer)
+        rollouts, correctness_flags = perform_rollouts(
+            root_node, lm_model, initial_rollouts
+        )
+        mc_score = calculate_mc_score(root_node)
+        root_node.mc_score = mc_score
+
+        nodes.append(root_node)
+
+        # Check if further processing is needed
+        if 0 < sum(correctness_flags) < initial_rollouts:
+            print("Processing annotations ...\n")
+            filename = f"{file_prefix}_{i+1}_nodes_data.json"
+            process_annotations(
+                problem, nodes, lm_model, filename, max_iterations
+            )
+
+    # Log completion
+    logging.info("Finished processing the JSON file.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/omegaPRM_openR/model_utils.py b/examples/omegaPRM_openR/model_utils.py
new file mode 100644
index 0000000000..317d03c898
--- /dev/null
+++ b/examples/omegaPRM_openR/model_utils.py
@@ -0,0 +1,96 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import os
+import random
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from camel.agents import ChatAgent
+from camel.models import ModelFactory
+from camel.types import ModelPlatformType
+
+load_dotenv()
+
+
+class LM:
+    def __init__(self, model_type, model_name, num_rollouts=5, **kwargs):
+        self.model_type = model_type
+        self.model_name = model_name
+        self.num_rollouts = num_rollouts
+        self.max_tokens = kwargs.get('max_tokens', 4096)
+        self.temperature_range = kwargs.get('temperature_range', [0.7, 1.0])
+
+        if self.model_type != "camel":
+            raise ValueError("Only camel model type is supported")
+
+        # Initialize camel model
+        self.model = ModelFactory.create(
+            model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
+            model_type=os.environ.get("OPENAI_COMPATIBILITY_ModelType"),
+            api_key=os.environ.get("OPENAI_COMPATIBILITY_API_KEY"),
+            url=os.environ.get("OPENAI_COMPATIBILITY_API_BASE_URL"),
+            model_config_dict={
+                "temperature": random.uniform(*self.temperature_range),
+                "max_tokens": self.max_tokens,
+            },
+        )
+
+        # Initialize chat agent
+        self.agent = ChatAgent(
+            system_message='''You are a mathematical reasoning expert who
+always solves problems step by step.
+For each step:
+1. Write down what you're calculating
+2. Show the calculation
+3. Explain the result
+Always show your work, even for simple calculations.
+End your solution with the final numerical answer.''',
+            model=self.model,
+            message_window_size=10,
+        )
+
+    def generate(self, question, partial_answer, num_rollouts=None):
+        results = []
+        if num_rollouts is None:
+            num_rollouts = self.num_rollouts
+
+        for _ in tqdm(range(num_rollouts)):
+            # Update temperature for each rollout
+            self.model.model_config_dict["temperature"] = random.uniform(
+                *self.temperature_range
+            )
+
+            # Construct the prompt
+            if partial_answer:
+                prompt = f"""Problem: {question}
+Current solution steps:
+{partial_answer}
+Continue the solution, showing all steps and calculations.
+Make sure to explain each step:"""
+            else:
+                prompt = f"""Problem: {question}
+Please solve this step by step, showing all calculations and
+explaining each step.
+Remember to:
+1. Break down the problem
+2. Show all calculations
+3. Explain each step
+4. End with the final numerical answer."""
+
+            # Get response from agent
+            response = self.agent.step(prompt)
+            results.append(response.msgs[0].content)
+
+        return results
diff --git a/examples/omegaPRM_openR/module.py b/examples/omegaPRM_openR/module.py
new file mode 100644
index 0000000000..a73a35037f
--- /dev/null
+++ b/examples/omegaPRM_openR/module.py
@@ -0,0 +1,223 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import json
+import math
+import os
+import re
+
+from model_utils import LM
+
+
+class Node:
+    def __init__(self, question, partial_answer, correct_answer):
+        self.question = question
+        self.partial_answer = partial_answer
+        self.correct_answer = correct_answer
+        self.mc_score = None
+        self.visits = 0
+        self.rollouts = []
+        self.visited_rollouts = []
+
+    def add_rollout(self, result):
+        self.rollouts.append(result)
+        self.visited_rollouts.append(False)
+
+    def increment_visits(self):
+        self.visits += 1
+
+
+# Evaluation
+def check_correctness(expected_answer, generated_response):
+    # A rollout counts as correct if the expected answer appears in the
+    # final sentence of the generated solution.
+    sentences = re.split(
+        r'(?<=[.!?])\s+', generated_response.strip()
+    )
+    last_sentence = sentences[-1] if sentences else ''
+    return expected_answer.strip() in last_sentence.strip()
+
+
+# Rollouts
+def perform_rollouts(node, model, num_rollouts=None):
+    correctness_flags = []
+    results = model.generate(node.question, node.partial_answer, num_rollouts)
+    for result in results:
+        node.add_rollout(result)
+        is_correct = check_correctness(node.correct_answer, result)
+        correctness_flags.append(int(is_correct))
+    return node.rollouts, correctness_flags
+
+
+def calculate_mc_score(node):
+    # Monte Carlo score: fraction of rollouts that reach the correct answer.
+    correct_count = sum(
+        check_correctness(node.correct_answer, rollout)
+        for rollout in node.rollouts
+    )
+    return correct_count / len(node.rollouts) if node.rollouts else 0
+
+
+# Selection
+def select_best_node(nodes):
+    best_node = None
+    best_rollout_idx = -1
+    highest_qu_value = -1
+    for node in nodes:
+        mc_score = (
+            node.mc_score
+            if node.mc_score is not None
+            else calculate_mc_score(node)
+        )
+        for idx, rollout in enumerate(node.rollouts):
+            # Only consider unvisited, incorrect rollouts: locate_error
+            # binary-searches a rollout that is assumed to contain an error.
+            if node.visited_rollouts[idx] or check_correctness(
+                node.correct_answer, rollout
+            ):
+                continue
+            q_value = compute_q_value(rollout, mc_score)
+            u_value = compute_u_value(node, nodes)
+            qu_value = q_value + u_value
+            if qu_value > highest_qu_value:
+                highest_qu_value = qu_value
+                best_node = node
+                best_rollout_idx = idx
+    if best_rollout_idx != -1 and best_node is not None:
+        best_node.visited_rollouts[best_rollout_idx] = True
+        return (
+            best_node,
+            best_node.rollouts[best_rollout_idx],
+            highest_qu_value,
+        )
+    else:
+        return None, None, None
+
+
+def split_text_middle(text):
+    text = text.strip()
+    mid_idx = len(text) // 2
+    if text[mid_idx] != ' ':
+        left_space = text.rfind(' ', 0, mid_idx)
+        right_space = text.find(' ', mid_idx)
+        if left_space == -1:
+            split_idx = right_space
+        elif right_space == -1:
+            split_idx = left_space
+        else:
+            split_idx = (
+                left_space
+                if (mid_idx - left_space) <= (right_space - mid_idx)
+                else right_space
+            )
+    else:
+        split_idx = mid_idx
+    part1 = text[:split_idx].strip()
+    part2 = text[split_idx:].strip()
+    return part1, part2
+
+
+def locate_error(node, rollout, model):
+    current_span = rollout
+    previous_text = ""
+    nodes_to_expand = []
+    leaf_nodes = []
+    while True:
+        if len(current_span.split()) < 2:
+            break
+        left_part, right_part = split_text_middle(current_span)
+        print("----")
+        print(" Left:", left_part)
+        print(" Right:", right_part)
+        new_node = Node(
+            node.question, previous_text + left_part, node.correct_answer
+        )
+        perform_rollouts(new_node, model)
+        mc_score = calculate_mc_score(new_node)
+        new_node.mc_score = mc_score
+        if mc_score == 1:
+            break
+        elif mc_score > 0:
+            current_span = right_part
+            previous_text += left_part
+            nodes_to_expand.append(new_node)
+        else:
+            current_span = left_part
+            leaf_nodes.append(new_node)
+    print("----")
+    return nodes_to_expand, leaf_nodes
+
+
+def compute_q_value(
+    rollout_text, mc_score, alpha=0.5, beta=0.9, max_length=500
+):
+    part1 = alpha ** (1 - mc_score)
+    part2 = beta ** (len(rollout_text) / max_length)
+    return part1 * part2
+
+
+def compute_u_value(node, all_nodes, exploration_param=0.125):
+    total_visits = sum(n.visits for n in all_nodes)
+    numerator = math.sqrt(total_visits)
+    denominator = 1 + node.visits
+    return exploration_param * (numerator / denominator)
+
+
+def process_annotations(
+    question, nodes, model: LM, filename='nodes_data.json', max_iterations=100
+):
+    print("++++++")
+    iteration = 0
+    leaf_nodes = []
+    while True:
+        node, rollout, max_qu = select_best_node(nodes)
+        if node is not None and node.partial_answer != '':
+            new_entry = {
+                "question": question,
+                "partial_answer": node.partial_answer,
+                "mc_score": node.mc_score,
+            }
+            append_to_json(filename, new_entry)
+        iteration += 1
+        if iteration > max_iterations:
+            break
+        if node is None:
+            break
+        print()
+        print("[Selected Node]")
+        print(node)
+        print(" Rollout:", rollout, " || QU Value:", max_qu)
+        node.increment_visits()
+        expanded_nodes, leaves = locate_error(node, rollout, model)
+        if not expanded_nodes:
+            continue
+        nodes.extend(
+            n
+            for n in expanded_nodes
+            if n is not None and n.partial_answer != ''
+        )
+        leaf_nodes.extend(leaves)
+    for leaf_node in leaf_nodes:
+        new_entry = {
+            "question": question,
+            "partial_answer": leaf_node.partial_answer,
+            "mc_score": leaf_node.mc_score,
+        }
+        append_to_json(filename, new_entry)
+    print("++++++")
+
+
+# Utils
+def append_to_json(filename, data_entry):
+    if os.path.exists(filename):
+        with open(filename, 'r') as file:
+            data = json.load(file)
+    else:
+        data = []
+    data.append(data_entry)
+    with open(filename, 'w') as file:
+        json.dump(data, file, indent=4)
+    print(f"Data appended to {filename}")
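For reference, select_best_node in module.py ranks every unvisited rollout by Q + U: compute_q_value grows with the node's Monte Carlo (MC) score (alpha ** (1 - mc) with alpha < 1) and shrinks as the rollout text gets longer (beta ** (len / max_length) with beta < 1), while compute_u_value adds a UCB-style exploration bonus that decays as the node accumulates visits. A minimal sketch of that scoring on a hand-built node, assuming module.py above is importable (it pulls in camel through model_utils) and using made-up values:

# Illustrative only: score one node's rollout the way select_best_node does.
from module import Node, compute_q_value, compute_u_value

node = Node("example question", "partial answer so far", "42")
node.add_rollout("Step 1: ... Step 2: ... The final answer is 42.")
node.visits = 3  # pretend this node has already been selected a few times

# Q = alpha ** (1 - mc_score) * beta ** (len(rollout) / max_length);
# the mc_score of 0.25 here is an arbitrary example value.
q = compute_q_value(node.rollouts[0], mc_score=0.25)

# U = exploration_param * sqrt(total visits over all nodes) / (1 + node visits)
u = compute_u_value(node, [node])

print(f"Q = {q:.4f}, U = {u:.4f}, Q + U = {q + u:.4f}")

To try the example end to end: copy .env.example to .env, fill in OPENAI_COMPATIBILITY_API_KEY, and run "python gen_data.py" from examples/omegaPRM_openR/. For each problem whose initial rollouts are neither all correct nor all wrong, the annotations are appended to example_<n>_nodes_data.json as a list of {question, partial_answer, mc_score} entries.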