From 12ec2ec3690636bb09e443ee13c12e7a595f92e3 Mon Sep 17 00:00:00 2001
From: zjrwtx <3038880699@qq.com>
Date: Fri, 13 Dec 2024 19:30:35 +0800
Subject: [PATCH 1/3] add examples

---
 examples/omegaPRM_openR/.env.example          |   5 +
 examples/omegaPRM_openR/config.yaml           |  25 ++
 examples/omegaPRM_openR/example_problems.json |   6 +
 examples/omegaPRM_openR/gen_data.py           | 146 ++++++++++++
 examples/omegaPRM_openR/model_utils.py        |  96 ++++++++
 examples/omegaPRM_openR/module.py             | 223 ++++++++++++++++++
 6 files changed, 501 insertions(+)
 create mode 100644 examples/omegaPRM_openR/.env.example
 create mode 100644 examples/omegaPRM_openR/config.yaml
 create mode 100644 examples/omegaPRM_openR/example_problems.json
 create mode 100644 examples/omegaPRM_openR/gen_data.py
 create mode 100644 examples/omegaPRM_openR/model_utils.py
 create mode 100644 examples/omegaPRM_openR/module.py

diff --git a/examples/omegaPRM_openR/.env.example b/examples/omegaPRM_openR/.env.example
new file mode 100644
index 0000000000..49f6eb1d68
--- /dev/null
+++ b/examples/omegaPRM_openR/.env.example
@@ -0,0 +1,5 @@
+# Set these variables to the OpenAI-compatible model you want to call, e.g. the deepseek-chat model.
+
+OPENAI_COMPATIBILIY_ModelType=deepseek-chat
+OPENAI_COMPATIBILIY_API_BASE_URL=https://api.deepseek.com
+OPENAI_COMPATIBILIY_API_KEY=
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/config.yaml b/examples/omegaPRM_openR/config.yaml
new file mode 100644
index 0000000000..7d59760b39
--- /dev/null
+++ b/examples/omegaPRM_openR/config.yaml
@@ -0,0 +1,25 @@
+input:
+  json_file_path: 'example_problems.json'
+
+output:
+  file_prefix: 'example'
+  log_file_path: 'example_processing.log'
+
+processing:
+  initial_rollouts: 20
+  num_rollouts: 20
+  max_iterations: 100
+
+model:
+  model_type: "camel"
+  model_name: "deepseek-chat"
+  model_args:
+    max_tokens: 200
+    temperature_range: [0.7, 1.0]
+
+
+# For each question, `initial_rollouts` rollouts are executed up front.
+# If a question is too easy (all rollouts correct) or too hard (all wrong),
+# OmegaPRM will not process it. So you may need a problem set of the right
+# difficulty and a model that is not too strong, such as one of Qwen's small
+# open-source models, in order to obtain node data.
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/example_problems.json b/examples/omegaPRM_openR/example_problems.json
new file mode 100644
index 0000000000..bfec425166
--- /dev/null
+++ b/examples/omegaPRM_openR/example_problems.json
@@ -0,0 +1,6 @@
+[
+    {
+        "problem": "How many ways can we put 3 math books and 5 English books on a shelf if all the math books must stay together and all the English books must also stay together? (The math books are all different and so are the English books.)",
+        "final_answer": "1440"
+    }
+]
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/gen_data.py b/examples/omegaPRM_openR/gen_data.py
new file mode 100644
index 0000000000..d412256fd7
--- /dev/null
+++ b/examples/omegaPRM_openR/gen_data.py
@@ -0,0 +1,146 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +import json +import logging + +import yaml +from model_utils import LM +from module import ( + Node, + calculate_mc_score, + perform_rollouts, + process_annotations, +) + + +def load_config(config_path): + """ + Load configuration from a YAML file. + + Args: + config_path (str): Path to the YAML configuration file. + + Returns: + dict: A dictionary containing the configuration. + """ + with open(config_path, 'r') as file: + return yaml.safe_load(file) + + +def load_json_file(file_path): + """ + Load data from a JSON file. + + Args: + file_path (str): Path to the JSON file. + + Returns: + list: A list of dictionaries containing the problem and final answer. + """ + with open(file_path, 'r') as file: + data = json.load(file) + return data + + +def setup_logging(log_file): + """ + Set up logging configuration to output to file and console. + + Args: + log_file (str): Path to the log file. + """ + logging.basicConfig( + filename=log_file, + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + console_handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.addHandler(console_handler) + + logging.getLogger("openai").setLevel(logging.ERROR) + logging.getLogger("httpx").setLevel(logging.WARNING) + + +def main(): + # Load configuration + config = load_config('config.yaml') + + # Get parameters from config + json_file_path = config['input']['json_file_path'] + log_file_path = config['output']['log_file_path'] + file_prefix = config['output']['file_prefix'] + num_rollouts = config['processing']['num_rollouts'] + initial_rollouts = config['processing']['initial_rollouts'] + max_iterations = config['processing']['max_iterations'] + + lm_model = LM( + model_type=config['model']['model_type'], + model_name=config['model']['model_name'], + num_rollouts=num_rollouts, + **config['model']['model_args'], + ) + + # Set up logging + setup_logging(log_file_path) + + # Start the process and log it + logging.info("Started processing the JSON file.") + + # Load the JSON data + data = load_json_file(json_file_path) + + # Process each problem and its final answer + for i, item in enumerate(data): + problem = item.get('problem', 'No problem found') + final_answer = item.get('final_answer', 'No answer found') + + # Log each problem and answer + logging.info(f"Processed Problem {i + 1}: {problem}") + logging.info(f"Final Answer: {final_answer}") + + # Initialize the root node and perform rollouts + nodes = [] + root_node = Node(problem, "", final_answer) + rollouts, correctness_flags = perform_rollouts( + root_node, lm_model, initial_rollouts + ) + mc_score = calculate_mc_score(root_node) + root_node.mc_score = mc_score + + nodes.append(root_node) + + # Check if further processing is needed + if 0 < sum(correctness_flags) < initial_rollouts: + 
print("Processing annotations ...\n") + filename = f"{file_prefix}_{i+1}_nodes_data.json" + process_annotations( + problem, nodes, lm_model, filename, max_iterations + ) + + # Log completion + logging.info("Finished processing the JSON file.") + + +if __name__ == "__main__": + main() diff --git a/examples/omegaPRM_openR/model_utils.py b/examples/omegaPRM_openR/model_utils.py new file mode 100644 index 0000000000..317d03c898 --- /dev/null +++ b/examples/omegaPRM_openR/model_utils.py @@ -0,0 +1,96 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +import os +import random + +from dotenv import load_dotenv +from tqdm import tqdm + +from camel.agents import ChatAgent +from camel.models import ModelFactory +from camel.types import ModelPlatformType + +load_dotenv() + + +class LM: + def __init__(self, model_type, model_name, num_rollouts=5, **kwargs): + self.model_type = model_type + self.model_name = model_name + self.num_rollouts = num_rollouts + self.max_tokens = kwargs.get('max_tokens', 4096) + self.temperature_range = kwargs.get('temperature_range', [0.7, 1.0]) + + if self.model_type != "camel": + raise ValueError("Only camel model type is supported") + + # Initialize camel model + self.model = ModelFactory.create( + model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL, + model_type=os.environ.get("OPENAI_COMPATIBILIY_ModelType"), + api_key=os.environ.get("OPENAI_COMPATIBILIY_API_KEY"), + url=os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL"), + model_config_dict={ + "temperature": random.uniform(*self.temperature_range), + "max_tokens": self.max_tokens, + }, + ) + + # Initialize chat agent + self.agent = ChatAgent( + system_message='''You are a mathematical reasoning expert who + always solves problems step by step. +For each step: +1. Write down what you're calculating +2. Show the calculation +3. Explain the result +Always show your work, even for simple calculations. +End your solution with the final numerical answer.''', + model=self.model, + message_window_size=10, + ) + + def generate(self, question, partial_answer, num_rollouts=None): + results = [] + if num_rollouts is None: + num_rollouts = self.num_rollouts + + for _ in tqdm(range(num_rollouts)): + # Update temperature for each rollout + self.model.model_config_dict["temperature"] = random.uniform( + *self.temperature_range + ) + + # Construct the prompt + if partial_answer: + prompt = f"""Problem: {question} +Current solution steps: +{partial_answer} +Continue the solution, showing all steps and calculations. +Make sure to explain each step:""" + else: + prompt = f"""Problem: {question} +Please solve this step by step, showing all calculations and + explaining each step. +Remember to: +1. Break down the problem +2. Show all calculations +3. Explain each step +4. 
End with the final numerical answer."""
+
+            # Get response from agent
+            response = self.agent.step(prompt)
+            results.append(response.msgs[0].content)
+
+        return results
diff --git a/examples/omegaPRM_openR/module.py b/examples/omegaPRM_openR/module.py
new file mode 100644
index 0000000000..a73a35037f
--- /dev/null
+++ b/examples/omegaPRM_openR/module.py
@@ -0,0 +1,223 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import json
+import math
+import os
+import re
+
+from model_utils import LM
+
+
+class Node:
+    def __init__(self, question, partial_answer, correct_answer):
+        self.question = question
+        self.partial_answer = partial_answer
+        self.correct_answer = correct_answer
+        self.mc_score = None
+        self.visits = 0
+        self.rollouts = []
+        self.visited_rollouts = []
+
+    def add_rollout(self, result):
+        self.rollouts.append(result)
+        self.visited_rollouts.append(False)
+
+    def increment_visits(self):
+        self.visits += 1
+
+
+# Evaluation
+def check_correctness(expected_answer, generated_response):
+    # Split the response into sentences and look for the expected
+    # answer in the final ones.
+    sentences = re.split(
+        r'(?<=[.!?])\s+', generated_response.strip()
+    )
+    last_sentences = ' '.join(sentences[-2:]) if sentences else ''
+    return expected_answer.strip() in last_sentences.strip()
+
+
+# Rollouts
+def perform_rollouts(node, model, num_rollouts=5):
+    # Generate rollouts from the node's partial answer and record
+    # whether each one reaches the correct final answer.
+    correctness_flags = []
+    results = model.generate(
+        node.question, node.partial_answer, num_rollouts
+    )
+    for result in results:
+        node.add_rollout(result)
+        is_correct = check_correctness(node.correct_answer, result)
+        correctness_flags.append(int(is_correct))
+    return node.rollouts, correctness_flags
+
+
+def calculate_mc_score(node):
+    # Monte Carlo score = fraction of rollouts that are correct.
+    correct_count = sum(
+        check_correctness(node.correct_answer, r) for r in node.rollouts
+    )
+    return correct_count / len(node.rollouts) if node.rollouts else 0
+
+
+# Selection
+def select_best_node(nodes):
+    # Pick the unvisited rollout with the highest Q + U value among
+    # nodes whose MC score is strictly between 0 and 1.
+    best_node = None
+    best_rollout_idx = -1
+    highest_qu_value = -1
+    for node in nodes:
+        if node.mc_score is None or node.mc_score in (0, 1):
+            continue
+        for idx, rollout in enumerate(node.rollouts):
+            if node.visited_rollouts[idx]:
+                continue
+            qu_value = compute_q_value(rollout, node.mc_score)
+            qu_value += compute_u_value(node, nodes)
+            if qu_value > highest_qu_value:
+                highest_qu_value = qu_value
+                best_node = node
+                best_rollout_idx = idx
+    if best_rollout_idx != -1 and best_node is not None:
+        best_node.visited_rollouts[best_rollout_idx] = True
+        return (
+            best_node,
+            best_node.rollouts[best_rollout_idx],
+            highest_qu_value,
+        )
+    else:
+        return None, None, None
+
+
+def split_text_middle(text):
+    text = text.strip()
+    mid_idx = len(text) // 2
+    if text[mid_idx] != ' ':
+        left_space = text.rfind(' ', 0, mid_idx)
+        right_space = text.find(' ', mid_idx)
+        if left_space == -1:
+            split_idx = right_space
+        elif right_space == -1:
+            split_idx = left_space
+        else:
+            split_idx = (
+                left_space
+                if (mid_idx - left_space) <= (right_space - mid_idx)
+                else right_space
+            )
+    else:
+        split_idx = mid_idx
+    part1 = text[:split_idx].strip()
+    part2 = text[split_idx:].strip()
+    return part1, part2
+
+
+def locate_error(node, rollout, model):
+    current_span = rollout
+    previous_text = ""
+    nodes_to_expand = []
+    leaf_nodes = []
+    while True:
+        if len(current_span.split()) < 2:
+            break
+        left_part, right_part = split_text_middle(current_span)
+        print("----")
+        print(" Left:", left_part)
+        print(" Right:", right_part)
+        new_node = Node(
+            node.question, previous_text + left_part, node.correct_answer
+        )
+        perform_rollouts(new_node, model)
+        mc_score = calculate_mc_score(new_node)
+        new_node.mc_score = mc_score
+        if mc_score == 1:
+            break
+        elif mc_score > 0:
+            current_span = right_part
+            previous_text += left_part
+            nodes_to_expand.append(new_node)
+        else:
+            current_span = left_part
+            leaf_nodes.append(new_node)
+    print("----")
+    return nodes_to_expand, leaf_nodes
+
+
+def compute_q_value(
+    rollout_text, mc_score, alpha=0.5, beta=0.9, max_length=500
+):
+    part1 = alpha ** (1 - mc_score)
+    part2 = beta ** (len(rollout_text) / max_length)
+    return part1 * part2
+
+
+def compute_u_value(node, all_nodes, exploration_param=0.125):
+    total_visits = sum(n.visits for n in all_nodes)
+    numerator = math.sqrt(total_visits)
+    denominator = 1 + node.visits
+    return exploration_param * (numerator / denominator)
+
+
+def process_annotations(
+    question, nodes, model: LM, filename='nodes_data.json', max_iterations=100
+):
+    print("++++++")
+    iteration = 0
+    leaf_nodes = []
+    while True:
+        node, rollout, max_qu = select_best_node(nodes)
+        if node is not None and node.partial_answer != '':
+            new_entry = {
+                "question": question,
+                "partial_answer": node.partial_answer,
+                "mc_score": node.mc_score,
+            }
+            append_to_json(filename, new_entry)
+        iteration += 1
+        if iteration > max_iterations:
+            break
+        if node is None:
+            break
+        print()
+        print("[Selected Node]")
+        print(node)
+        print(" Rollout:", rollout, " || QU Value:", max_qu)
+        node.increment_visits()
+        expanded_nodes, leaves = locate_error(node, rollout, model)
+        if not expanded_nodes:
+            continue
+        nodes.extend(
+            n
+            for n in expanded_nodes
+            if n is not None and n.partial_answer != ''
+        )
+        leaf_nodes.extend(leaves)
+    for leaf_node in leaf_nodes:
+        new_entry = {
+            "question": question,
+            "partial_answer": leaf_node.partial_answer,
+            "mc_score": leaf_node.mc_score,
+        }
+        append_to_json(filename, new_entry)
+    print("++++++")
+
+
+# Utils
+def append_to_json(filename, data_entry):
+    if os.path.exists(filename):
+        with open(filename, 'r') as file:
+            data = json.load(file)
+    else:
+        data = []
+    data.append(data_entry)
+    with open(filename, 'w') as file:
+        json.dump(data, file, indent=4)
+    print(f"Data appended to {filename}")

From dcd18b792e909bc343b983a26c16264b2e6ac865 Mon Sep 17 00:00:00 2001
From: zjrwtx <3038880699@qq.com>
Date: Fri, 13 Dec 2024 21:06:21 +0800
Subject: [PATCH 2/3] refactor to v2 version

---
 examples/omegaPRM_openR/.env                  |   5 +
 .../{.env.example => .env copy.example}       |   0
 examples/omegaPRM_openR/config.yaml           |  25 --
 examples/omegaPRM_openR/example_problems.json |  18 +-
 examples/omegaPRM_openR/gen_data.py           | 101 ++++----
 examples/omegaPRM_openR/model_utils.py        |  25 ++
 examples/omegaPRM_openR/module.py             |  68 +++--
 examples/omegaPRM_openR/omegaprm_v2.py        | 240 ++++++++++++++++++
 examples/omegaPRM_openR/search_tree.py        | 146 +++++++++++
 9 files changed, 514 insertions(+), 114 deletions(-)
 create mode 100644 examples/omegaPRM_openR/.env
 rename examples/omegaPRM_openR/{.env.example => .env copy.example} (100%)
 delete mode 100644 examples/omegaPRM_openR/config.yaml
 create mode 100644 examples/omegaPRM_openR/omegaprm_v2.py
 create mode 100644 examples/omegaPRM_openR/search_tree.py

diff --git a/examples/omegaPRM_openR/.env b/examples/omegaPRM_openR/.env
new file mode 100644
index 0000000000..840c6eacf8
--- /dev/null
+++ b/examples/omegaPRM_openR/.env
@@ -0,0 +1,5 @@
+# Set these variables to the OpenAI-compatible model you want to call, e.g. the deepseek-chat model.
+
+OPENAI_COMPATIBILIY_ModelType=deepseek-chat
+OPENAI_COMPATIBILIY_API_BASE_URL=https://api.deepseek.com
+OPENAI_COMPATIBILIY_API_KEY=sk-89b57b7be04a4bb886e97115edbfe4f3
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/.env.example b/examples/omegaPRM_openR/.env copy.example
similarity index 100%
rename from examples/omegaPRM_openR/.env.example
rename to examples/omegaPRM_openR/.env copy.example
diff --git a/examples/omegaPRM_openR/config.yaml b/examples/omegaPRM_openR/config.yaml
deleted file mode 100644
index 7d59760b39..0000000000
--- a/examples/omegaPRM_openR/config.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-input:
-  json_file_path: 'example_problems.json'
-
-output:
-  file_prefix: 'example'
-  log_file_path: 'example_processing.log'
-
-processing:
-  initial_rollouts: 20
-  num_rollouts: 20
-  max_iterations: 100
-
-model:
-  model_type: "camel"
-  model_name: "deepseek-chat"
-  model_args:
-    max_tokens: 200
-    temperature_range: [0.7, 1.0]
-
-
-# For each question, `initial_rollouts` rollouts are executed up front.
-# If a question is too easy (all rollouts correct) or too hard (all wrong),
-# OmegaPRM will not process it. So you may need a problem set of the right
-# difficulty and a model that is not too strong, such as one of Qwen's small
-# open-source models, in order to obtain node data.
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/example_problems.json b/examples/omegaPRM_openR/example_problems.json
index bfec425166..4c4dabd11b 100644
--- a/examples/omegaPRM_openR/example_problems.json
+++ b/examples/omegaPRM_openR/example_problems.json
@@ -1,6 +1,18 @@
 [
     {
-        "problem": "How many ways can we put 3 math books and 5 English books on a shelf if all the math books must stay together and all the English books must also stay together? (The math books are all different and so are the English books.)",
-        "final_answer": "1440"
+        "problem": "A bag contains 6 red balls, 4 blue balls, and 5 green balls. If you draw 3 balls without replacement, what is the probability of drawing exactly 2 red balls and 1 blue ball? Express your answer as a fraction in simplest form.",
+        "final_answer": "12/91"
+    },
+    {
+        "problem": "In how many ways can 8 different books be arranged on a shelf if 3 specific books must always be next to each other (in any order among themselves)?",
+        "final_answer": "4320"
+    },
+    {
+        "problem": "A palindrome is a number that reads the same forwards and backwards. How many palindromes are there between 100 and 999?",
+        "final_answer": "90"
+    },
+    {
+        "problem": "Three fair coins are tossed. What is the probability of getting at least two heads? Express your answer as a fraction in simplest form.",
+        "final_answer": "1/2"
     }
-]
\ No newline at end of file
+]
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/gen_data.py b/examples/omegaPRM_openR/gen_data.py
index d412256fd7..15ea3f7cf9 100644
--- a/examples/omegaPRM_openR/gen_data.py
+++ b/examples/omegaPRM_openR/gen_data.py
@@ -83,63 +83,68 @@ def setup_logging(log_file):
 
 
 def main():
-    # Load configuration
-    config = load_config('config.yaml')
-
-    # Get parameters from config
-    json_file_path = config['input']['json_file_path']
-    log_file_path = config['output']['log_file_path']
-    file_prefix = config['output']['file_prefix']
-    num_rollouts = config['processing']['num_rollouts']
-    initial_rollouts = config['processing']['initial_rollouts']
-    max_iterations = config['processing']['max_iterations']
-
-    lm_model = LM(
-        model_type=config['model']['model_type'],
-        model_name=config['model']['model_name'],
-        num_rollouts=num_rollouts,
-        **config['model']['model_args'],
-    )
+    # Direct configuration instead of loading from yaml
+    config = {
+        'input': {'json_file_path': 'example_problems.json'},
+        'output': {
+            'file_prefix': 'example',
+            'log_file_path': 'example_processing.log',
+        },
+        'processing': {
+            'initial_rollouts': 30,  # Increased number of initial rollouts
+            'num_rollouts': 25,  # Increased rollouts per iteration
+            'max_iterations': 150,  # Increased maximum iteration count
+        },
+        'model': {
+            'model_type': 'camel',
+            'model_name': 'deepseek-chat',
+            'model_args': {
+                'max_tokens': 300,  # Increased maximum token count
+                'temperature_range': [0.6, 0.9],  # Adjusted for diversity
+            },
+        },
+    }
 
     # Set up logging
-    setup_logging(log_file_path)
+    setup_logging(config['output']['log_file_path'])
+    logging.info("Starting data generation process...")
 
-    # Start the process and log it
-    logging.info("Started processing the JSON file.")
+    # Load problems from JSON file
+    problems = load_json_file(config['input']['json_file_path'])
+    logging.info(f"Loaded {len(problems)} problems")
 
-    # Load the JSON data
-    data = load_json_file(json_file_path)
+    # Initialize the language model
+    model = LM(
+        model_type=config['model']['model_type'],
+        model_name=config['model']['model_name'],
+        **config['model']['model_args'],
+    )
 
-    # Process each problem and its final answer
-    for i, item in enumerate(data):
-        problem = item.get('problem', 'No problem found')
-        final_answer = item.get('final_answer', 'No answer found')
+    # Process each problem
+    for problem in problems:
+        question = problem['problem']
+        final_answer = problem['final_answer']
 
-        # Log each problem and answer
-        logging.info(f"Processed Problem {i + 1}: {problem}")
-        logging.info(f"Final Answer: {final_answer}")
+        # Create initial node
+        initial_node = Node(question, "", final_answer)
+        nodes = [initial_node]
 
-        # Initialize the root node and perform rollouts
-        nodes = []
-        root_node = Node(problem, "", final_answer)
-        rollouts, correctness_flags = perform_rollouts(
-            root_node, lm_model, initial_rollouts
+        # Perform initial rollouts
+        perform_rollouts(
+            initial_node, model, config['processing']['initial_rollouts']
         )
-        mc_score = calculate_mc_score(root_node)
-        root_node.mc_score = mc_score
-
-        nodes.append(root_node)
-
-        # Check if further processing is needed
-        if 0 < sum(correctness_flags) < initial_rollouts:
-            print("Processing annotations ...\n")
-            filename = 
f"{file_prefix}_{i+1}_nodes_data.json" - process_annotations( - problem, nodes, lm_model, filename, max_iterations - ) - # Log completion - logging.info("Finished processing the JSON file.") + logging.info("Data generation process completed") if __name__ == "__main__": diff --git a/examples/omegaPRM_openR/model_utils.py b/examples/omegaPRM_openR/model_utils.py index 317d03c898..55890314c6 100644 --- a/examples/omegaPRM_openR/model_utils.py +++ b/examples/omegaPRM_openR/model_utils.py @@ -13,6 +13,7 @@ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= import os import random +from typing import List, Optional from dotenv import load_dotenv from tqdm import tqdm @@ -94,3 +95,27 @@ def generate(self, question, partial_answer, num_rollouts=None): results.append(response.msgs[0].content) return results + + def generate_rollouts( + self, prompt: str, num_copies: Optional[int] = None + ) -> List[str]: + """Generate multiple rollouts for a given prompt. + + Args: + prompt (str): The input prompt to generate responses for + num_copies (Optional[int], optional): Number of copies to generate. + Defaults to None. + + Returns: + List[str]: List of generated responses + """ + if num_copies is None: + num_copies = 1 + + rollouts = [] + for _ in range(num_copies): + response = self.generate(prompt, "") + if response: + rollouts.append(response[0]) + + return rollouts diff --git a/examples/omegaPRM_openR/module.py b/examples/omegaPRM_openR/module.py index a73a35037f..337d9f74f7 100644 --- a/examples/omegaPRM_openR/module.py +++ b/examples/omegaPRM_openR/module.py @@ -17,6 +17,7 @@ import re from model_utils import LM +from omegaprm_v2 import OmegaPRMV2 class Node: @@ -169,45 +170,36 @@ def compute_u_value(node, all_nodes, exploration_param=0.125): def process_annotations( question, nodes, model: LM, filename='nodes_data.json', max_iterations=100 ): - print("++++++") - iteration = 0 - leaf_nodes = [] - while True: - node, rollout, max_qu = select_best_node(nodes) - if node is not None and node.partial_answer != '': - new_entry = { - "question": question, - "partial_answer": node.partial_answer, - "mc_score": node.mc_score, + """Process annotations using OmegaPRM v2.""" + # Initialize OmegaPRM v2 + omegaprm = OmegaPRMV2( + model=model, + c_puct=0.2, + alpha=0.5, + beta=0.9, + L=500, + k=5, + N=max_iterations, + rollout_budget=1000, + save_data_tree=True, + ) + + # Process each node + for node in nodes: + collected_data = omegaprm.run(node.question, node.correct_answer) + + # Save collected data + for data in collected_data: + data_entry = { + 'question': node.question, + 'correct_answer': node.correct_answer, + 'iteration': data['iteration'], + 'total_rollouts': data['total_rollouts'], + 'tree_structure': data['tree_structure'], } - append_to_json(filename, new_entry) - iteration += 1 - if iteration > max_iterations: - break - if node is None: - break - print() - print("[Selected Node]") - print(node) - print(" Rollout:", rollout, " || QU Value:", max_qu) - node.increment_visits() - expanded_nodes, leaves = locate_error(node, rollout, model) - if not expanded_nodes: - continue - nodes.extend( - n - for n in expanded_nodes - if n is not None and n.partial_answer != '' - ) - leaf_nodes.extend(leaves) - for leaf_node in leaf_nodes: - new_entry = { - "question": question, - "partial_answer": leaf_node.partial_answer, - "mc_score": leaf_node.mc_score, - } - append_to_json(filename, new_entry) - print("++++++") + append_to_json(filename, data_entry) + + return nodes # Utils diff --git 
a/examples/omegaPRM_openR/omegaprm_v2.py b/examples/omegaPRM_openR/omegaprm_v2.py
new file mode 100644
index 0000000000..190740179a
--- /dev/null
+++ b/examples/omegaPRM_openR/omegaprm_v2.py
@@ -0,0 +1,240 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+"""OmegaPRM v2 implementation with enhanced search capabilities."""
+
+import math
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from model_utils import LM
+from search_tree import CandidatePool, SearchTree, State
+
+
+def separate_steps(steps: List[str], mode: str = 'join') -> Any:
+    """Helper function to separate reasoning steps."""
+    delimiter = "\n\n"
+    if mode == 'join':
+        if not isinstance(steps, list):
+            raise TypeError(
+                "For 'join' mode, 'steps' must be a list of strings."
+            )
+        return delimiter.join(steps)
+    elif mode == 'split':
+        if not isinstance(steps, str):
+            raise TypeError("For 'split' mode, 'steps' must be a string.")
+        return steps.split(delimiter)
+    else:
+        raise ValueError("Mode should be either 'join' or 'split'.")
+
+
+def check_correctness(generated_response: str, expected_answer: str) -> bool:
+    """Helper function to check correctness of a generated response."""
+    sentences = re.split(
+        r'(?<=[.!?])\s+', generated_response.strip()
+    )
+    last_sentences = ' '.join(sentences[-2:]) if sentences else ''
+    return expected_answer.strip() in last_sentences.strip()
+
+
+class OmegaPRMV2:
+    def __init__(
+        self,
+        model: LM,
+        c_puct: float,
+        alpha: float,
+        beta: float,
+        L: int,
+        k: int,
+        N: int,
+        rollout_budget: int,
+        save_data_tree: bool,
+    ):
+        """Initialize OmegaPRM v2 with its search hyperparameters."""
+        self.LM = model
+        self.c_puct = c_puct
+        self.alpha = alpha
+        self.beta = beta
+        self.L = L
+        self.k = k
+        self.N = N
+        self.rollout_budget = rollout_budget
+        self.save_data_tree = save_data_tree
+
+        self.T = SearchTree()
+        self.C = CandidatePool()
+        self.n = 0
+        self.total_rollouts = 0
+        self.expected_answer = ""
+
+    def reset(self):
+        """Reset the search tree, candidate pool and counters."""
+        self.T = SearchTree()
+        self.C = CandidatePool()
+        self.n = 0
+        self.total_rollouts = 0
+
+    def monte_carlo_estimation(self, state: State):
+        """Generate k rollouts from state and estimate MC(s) = c / k."""
+        correct_rollouts = []
+        incorrect_rollouts = []
+        rollouts = self.LM.generate_rollouts(
+            state.get_full_solution(), self.k
+        )
+        self.total_rollouts += len(rollouts)
+        for rollout in rollouts:
+            full_solution = state.get_full_solution() + rollout
+            if check_correctness(full_solution, self.expected_answer):
+                correct_rollouts.append(rollout)
+            else:
+                incorrect_rollouts.append(rollout)
+        for rollout in correct_rollouts:
+            state.add_rollout(rollout)
+        for rollout in incorrect_rollouts:
+            state.add_incorrect_rollout(rollout)
+        state.MC = (
+            state.correct_rollouts / state.total_rollouts
+            if state.total_rollouts > 0
+            else 0
+        )
+        # Incorrect rollouts become candidates for further expansion
+        for rollout in incorrect_rollouts:
+            score = self.compute_selection_score(state, rollout)
+            self.C.add_or_update(state, rollout, score)
+
+    def compute_Q(self, state: State, rollout: str) -> float:
+        """Compute Q(s, r) value."""
+        if rollout in state.Q:
+            return state.Q[rollout]
+
+        # Count words in rollout
+        word_count = len(rollout.split())
+        length_penalty = math.pow(self.beta, word_count / self.L)
+        mc_weight = (
+            math.pow(self.alpha, 1 - state.MC) if state.MC is not None else 1
+        )
+
+        q_value = mc_weight * length_penalty
+        state.Q[rollout] = q_value
+        return q_value
+
+    def compute_U(self, state: State) -> float:
+        """Compute U(s) value."""
+        total_visits = (
+            sum(child.N for child in state.children) if state.children else 0
+        )
+        return self.c_puct * math.sqrt(total_visits) / (1 + state.N)
+
+    def compute_selection_score(self, state: State, rollout: str) -> float:
+        """Compute selection score: Score(s, r) = Q(s, r) + U(s)."""
+        return self.compute_Q(state, rollout) + self.compute_U(state)
+
+    def selection_phase(self) -> Tuple[Optional[State], Optional[str]]:
+        """Select (state, rollout) with highest score from candidate pool."""
+        return self.C.pop()
+
+    def add_correct_rollout_to_tree(self, parent_state: State, rollout: str):
+        """Add correct rollout to the tree as a child of parent_state."""
+        new_state = State(rollout, parent_state)
+        parent_state.children.append(new_state)
+        self.T.add_state(new_state)
+        return new_state
+
+    def binary_search_incorrect_step(
+        self, s_ast: State, steps: List[str], left: int, right: int
+    ):
+        """Recursively perform binary search to find incorrect steps."""
+        if left > right:
+            return
+
+        mid = (left + right) // 2
+        partial_solution = separate_steps(steps[: mid + 1])
+        full_solution = s_ast.get_full_solution() + partial_solution
+
+        # Check if this partial solution is correct
+        if check_correctness(full_solution, self.expected_answer):
+            # The error must be in the latter half
+            new_state = self.add_correct_rollout_to_tree(
+                s_ast, partial_solution
+            )
+            self.binary_search_incorrect_step(
+                new_state, steps[mid + 1 :], 0, len(steps[mid + 1 :]) - 1
+            )
+        else:
+            # The error must be in this half
+            self.binary_search_incorrect_step(s_ast, steps, left, mid - 1)
+
+    def expansion_phase_binary_search(self, parent_state: State, rollout: str):
+        """Expansion phase using binary search to find correct parts."""
+        steps = separate_steps(rollout, mode='split')
+        self.binary_search_incorrect_step(
+            parent_state, steps, 0, len(steps) - 1
+        )
+
+    def maintenance_phase(self, state: State):
+        """Update statistics and candidate pool for incorrect rollouts."""
+        state.N += 1
+
+        # Re-compute selection scores for all incorrect rollouts
+        for rollout in state.incorrect_rollouts:
+            score = self.compute_selection_score(state, rollout)
+            self.C.add_or_update(state, rollout, score)
+
+    def run(self, question: str, answer: str) -> List[Dict[str, Any]]:
+        """Execute the OmegaPRM algorithm."""
+        self.reset()
+        self.expected_answer = answer
+
+        # Initialize root state
+        root_state = State(question)
+        self.T.add_state(root_state)
+        self.monte_carlo_estimation(root_state)
+
+        collected_data = []
+        while self.n < self.N and self.total_rollouts < self.rollout_budget:
+            # Selection phase
+            s_ast, r_ast = self.selection_phase()
+            if s_ast is None or r_ast is None:
+                break
+
+            # Expansion phase
+            self.expansion_phase_binary_search(s_ast, r_ast)
+
+            # Maintenance phase
+            self.maintenance_phase(s_ast)
+
+            # Collect data if enabled
+            if self.save_data_tree:
+                collected_data.append(
+                    {
+                        'iteration': self.n,
+                        'total_rollouts': self.total_rollouts,
+                        'tree_structure': self.T.root.get_text_with_labels()
if self.T.root + else None, + } + ) + + self.n += 1 + + return collected_data diff --git a/examples/omegaPRM_openR/search_tree.py b/examples/omegaPRM_openR/search_tree.py new file mode 100644 index 0000000000..45f332ea70 --- /dev/null +++ b/examples/omegaPRM_openR/search_tree.py @@ -0,0 +1,146 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +"""Search tree and state management for OmegaPRM.""" + +import heapq +import itertools +from typing import Any, Dict, List, Optional, Tuple + + +class State: + """Represents a state in the search tree.""" + + def __init__(self, solution_prefix: str, parent: Optional['State'] = None): + self.solution_prefix = ( + solution_prefix # Solution prefix as a single string + ) + self.parent = parent # Reference to the parent state + self.N = 0 # Visit count (number of times selected) + self.total_rollouts = ( + 0 # Total number of rollouts generated from this state + ) + self.correct_rollouts = 0 # Number of correct rollouts + self.MC: Optional[float] = None # Monte Carlo estimation (c/k) + self.Q: Dict[ + str, float + ] = {} # Q(s, r): estimated value for each rollout + self.R: List[str] = [] # Set of all rollouts from this state + self.incorrect_rollouts: List[str] = [] # List of incorrect rollouts + self.children: List['State'] = [] # List of child states + + def add_rollout(self, rollout: str): + """Add a correct rollout to this state.""" + if rollout not in self.R: + self.R.append(rollout) + self.total_rollouts += 1 + self.correct_rollouts += 1 + + def add_incorrect_rollout(self, rollout: str): + """Add an incorrect rollout to this state.""" + if rollout not in self.incorrect_rollouts: + self.incorrect_rollouts.append(rollout) + self.total_rollouts += 1 + + def get_full_solution(self) -> str: + """Return the full solution by concatenating + all parent solution prefixes.""" + if self.parent is None: + return self.solution_prefix + return self.parent.get_full_solution() + self.solution_prefix + + def get_new_text(self) -> str: + """Return the new text added at this node compared to the parent.""" + return self.solution_prefix + + def get_text_with_labels(self) -> Dict[str, Any]: + """Return a nested dictionary with text and MC values.""" + result = { + 'text': self.get_new_text(), + 'mc_value': self.MC, + 'children': [], + } + for child in self.children: + result['children'].append(child.get_text_with_labels()) + return result + + +class SearchTree: + """Represents the search tree for OmegaPRM.""" + + def __init__(self): + self.root: Optional[State] = None + self.nodes: List[State] = [] # List of all states + + def add_state(self, state: State): + """Add a new state to the search tree.""" + self.nodes.append(state) + if self.root is None: + self.root = state + + +class CandidatePool: + """Priority queue with update capability for managing candidate states.""" + + def __init__(self): + self.heap: List[ + 
Tuple[float, int]
+        ] = []  # Heap of (-priority, unique_id)
+        self.entry_finder: Dict[
+            int, Tuple[float, int]
+        ] = {}  # Maps unique_id to (-priority, unique_id)
+        self.counter = itertools.count()  # Unique sequence count
+        self.id_to_rollout: Dict[
+            int, Tuple[State, str]
+        ] = {}  # Maps unique_id to (state, rollout)
+        self.latest_id_per_rollout: Dict[
+            Tuple[int, str], int
+        ] = {}  # Maps (state_id, rollout) to unique_id
+
+    def add_or_update(self, state: State, rollout: str, priority: float):
+        """Add a new rollout or update the priority of an existing rollout."""
+        state_id = id(state)
+        rollout_key = (state_id, rollout)
+
+        # Remove previous entry if it exists
+        if rollout_key in self.latest_id_per_rollout:
+            old_id = self.latest_id_per_rollout[rollout_key]
+            if old_id in self.entry_finder:
+                del self.entry_finder[old_id]
+                del self.id_to_rollout[old_id]
+
+        # Add new entry
+        unique_id = next(self.counter)
+        entry = (-priority, unique_id)
+        self.entry_finder[unique_id] = entry
+        self.id_to_rollout[unique_id] = (state, rollout)
+        self.latest_id_per_rollout[rollout_key] = unique_id
+        heapq.heappush(self.heap, entry)
+
+    def pop(self) -> Tuple[Optional[State], Optional[str]]:
+        """Pop the rollout with the highest priority."""
+        while self.heap:
+            neg_priority, unique_id = heapq.heappop(self.heap)
+            if unique_id in self.entry_finder:
+                del self.entry_finder[unique_id]
+                state, rollout = self.id_to_rollout[unique_id]
+                del self.id_to_rollout[unique_id]
+                rollout_key = (id(state), rollout)
+                if self.latest_id_per_rollout.get(rollout_key) == unique_id:
+                    del self.latest_id_per_rollout[rollout_key]
+                return state, rollout
+        return None, None
+
+    def is_empty(self) -> bool:
+        """Check if the candidate pool is empty."""
+        return len(self.entry_finder) == 0

From 60d95d37ae91268e09b77bdcd8c30ddaf9a59ac9 Mon Sep 17 00:00:00 2001
From: zjrwtx <3038880699@qq.com>
Date: Fri, 13 Dec 2024 21:31:32 +0800
Subject: [PATCH 3/3] ready to review the example

---
 examples/omegaPRM_openR/.env                                |  5 -----
 examples/omegaPRM_openR/{.env copy.example => .env.example} |  0
 examples/omegaPRM_openR/README.md                           | 43 +++++++++
 3 files changed, 43 insertions(+), 5 deletions(-)
 delete mode 100644 examples/omegaPRM_openR/.env
 rename examples/omegaPRM_openR/{.env copy.example => .env.example} (100%)
 create mode 100644 examples/omegaPRM_openR/README.md

diff --git a/examples/omegaPRM_openR/.env b/examples/omegaPRM_openR/.env
deleted file mode 100644
index 840c6eacf8..0000000000
--- a/examples/omegaPRM_openR/.env
+++ /dev/null
@@ -1,5 +0,0 @@
-# Set these variables to the OpenAI-compatible model you want to call, e.g. the deepseek-chat model.
-
-OPENAI_COMPATIBILIY_ModelType=deepseek-chat
-OPENAI_COMPATIBILIY_API_BASE_URL=https://api.deepseek.com
-OPENAI_COMPATIBILIY_API_KEY=sk-89b57b7be04a4bb886e97115edbfe4f3
\ No newline at end of file
diff --git a/examples/omegaPRM_openR/.env copy.example b/examples/omegaPRM_openR/.env.example
similarity index 100%
rename from examples/omegaPRM_openR/.env copy.example
rename to examples/omegaPRM_openR/.env.example
diff --git a/examples/omegaPRM_openR/README.md b/examples/omegaPRM_openR/README.md
new file mode 100644
index 0000000000..081c38c0d9
--- /dev/null
+++ b/examples/omegaPRM_openR/README.md
@@ -0,0 +1,43 @@
+# OmegaPRM OpenR
+
+This demo code is adapted from [OpenR](https://github.com/openreasoner/openr), which is released under the MIT License.
+
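+## Usage
+
+Copy `.env.example` to `.env` and fill in the `OPENAI_COMPATIBILIY_*`
+variables, then run `python gen_data.py` to generate process-supervision
+data for the problems in `example_problems.json`.
+
+The snippet below is a minimal sketch of driving `OmegaPRMV2` directly;
+it assumes `.env` is configured and that the script sits next to
+`model_utils.py` and `omegaprm_v2.py`. The problem, answer, and budget
+values here are illustrative only.
+
+```python
+from model_utils import LM
+from omegaprm_v2 import OmegaPRMV2
+
+# Build the CAMEL-backed model wrapper (reads the OPENAI_COMPATIBILIY_* vars)
+model = LM(model_type="camel", model_name="deepseek-chat", max_tokens=300)
+
+# Hyperparameters mirror the process_annotations() call in module.py
+omegaprm = OmegaPRMV2(
+    model=model,
+    c_puct=0.2,  # exploration constant in U(s)
+    alpha=0.5,  # MC-score weight in Q(s, r)
+    beta=0.9,  # length-penalty base in Q(s, r)
+    L=500,  # rollout length normalizer
+    k=5,  # rollouts per Monte Carlo estimation
+    N=10,  # maximum search iterations (illustrative)
+    rollout_budget=100,  # illustrative rollout budget
+    save_data_tree=True,
+)
+
+# Run the search on one problem and inspect the annotated tree
+collected_data = omegaprm.run(
+    "Three fair coins are tossed. What is the probability of getting "
+    "at least two heads? Express your answer as a fraction in simplest form.",
+    "1/2",
+)
+print(collected_data[-1]["tree_structure"] if collected_data else "No data")
+```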