-
Notifications
You must be signed in to change notification settings - Fork 731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: OmegaPRM reproduced from OpenR — process-supervision data generation (PRM) #1280
base: master
Are you sure you want to change the base?
Changes from 3 commits
12ec2ec
dcd18b7
60d95d3
59293d1
b57d759
1f90dbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# When customizing the model, set the model type to the model you want to call, e.g. the deepseek-chat model
|
||
OPENAI_COMPATIBILIY_ModelType=deepseek-chat | ||
OPENAI_COMPATIBILIY_API_BASE_URL=https://api.deepseek.com | ||
OPENAI_COMPATIBILIY_API_KEY= |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# OmegaPRM OpenR | ||
|
||
This demo code refers to [OpenR](https://github.com/openreasoner/openr), which is released under the MIT License. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[ | ||
{ | ||
"problem": "A bag contains 6 red balls, 4 blue balls, and 5 green balls. If you draw 3 balls without replacement, what is the probability of drawing exactly 2 red balls and 1 blue ball? Express your answer as a fraction in simplest form.", | ||
"final_answer": "1/11" | ||
}, | ||
{ | ||
"problem": "In how many ways can 8 different books be arranged on a shelf if 3 specific books must always be next to each other (in any order among themselves)?", | ||
"final_answer": "720" | ||
}, | ||
{ | ||
"problem": "A palindrome is a number that reads the same forwards and backwards. How many palindromes are there between 100 and 999?", | ||
"final_answer": "90" | ||
}, | ||
{ | ||
"problem": "Three fair coins are tossed. What is the probability of getting at least two heads? Express your answer as a fraction in simplest form.", | ||
"final_answer": "1/2" | ||
} | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
import json | ||
import logging | ||
|
||
import yaml | ||
from model_utils import LM | ||
from module import ( | ||
Node, | ||
calculate_mc_score, | ||
perform_rollouts, | ||
process_annotations, | ||
) | ||
|
||
|
||
def load_config(config_path):
    """Read a YAML configuration file and return its contents.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        dict: Parsed configuration mapping.
    """
    with open(config_path, 'r') as fh:
        parsed = yaml.safe_load(fh)
    return parsed
|
||
|
||
def load_json_file(file_path):
    """Deserialize a JSON file from disk.

    Args:
        file_path (str): Path to the JSON file.

    Returns:
        list: Problem records, each a dict containing the problem
            statement and its final answer.
    """
    with open(file_path, 'r') as fh:
        return json.load(fh)
|
||
|
||
def setup_logging(log_file):
    """Configure root logging to write to both a file and the console.

    Args:
        log_file (str): Path to the log file.
    """
    fmt = '%(asctime)s - %(levelname)s - %(message)s'
    datefmt = '%Y-%m-%d %H:%M:%S'

    # File handler via basicConfig (no-op if the root logger is
    # already configured).
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format=fmt,
        datefmt=datefmt,
    )

    # Mirror records to the console in addition to the file.
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
    logging.getLogger().addHandler(stream_handler)

    # Quiet noisy third-party client libraries.
    logging.getLogger("openai").setLevel(logging.ERROR)
    logging.getLogger("httpx").setLevel(logging.WARNING)
|
||
|
||
def main():
    """Generate process-supervision (PRM) data for the example problems.

    Loads the example problems, configures logging, builds the language
    model wrapper, then runs OmegaPRM-style rollouts and annotation for
    each problem, writing node data to a JSON file.
    """
    # Direct configuration instead of loading from yaml
    config = {
        'input': {'json_file_path': 'example_problems.json'},
        'output': {
            'file_prefix': 'example',
            'log_file_path': 'example_processing.log',
        },
        'processing': {
            'initial_rollouts': 30,  # Increased number of initial rollouts
            'num_rollouts': 25,  # Increased rollouts per iteration
            # NOTE(review): 'num_rollouts' is never read below — verify
            # whether process_annotations should receive it.
            'max_iterations': 150,  # Increased maximum iteration count
        },
        'model': {
            'model_type': 'camel',
            'model_name': 'deepseek-chat',
            'model_args': {
                'max_tokens': 300,  # Increased maximum token budget
                # Wider temperature range for more diverse samples
                'temperature_range': [0.6, 0.9],
            },
        },
    }

    # Set up logging
    setup_logging(config['output']['log_file_path'])
    logging.info("Starting data generation process...")

    # Load problems from JSON file
    problems = load_json_file(config['input']['json_file_path'])
    logging.info(f"Loaded {len(problems)} problems")

    # Initialize the language model
    model = LM(
        model_type=config['model']['model_type'],
        model_name=config['model']['model_name'],
        model_args=config['model']['model_args'],
    )

    # Process each problem
    for problem in problems:
        question = problem['problem']
        final_answer = problem['final_answer']

        # Create initial node (empty partial answer; the final answer is
        # used downstream to score rollouts)
        initial_node = Node(question, "", final_answer)
        nodes = [initial_node]

        # Perform initial rollouts
        perform_rollouts(
            initial_node, model, config['processing']['initial_rollouts']
        )
        calculate_mc_score(initial_node)

        # Process annotations; all problems append into the same
        # <prefix>_nodes_data.json file
        process_annotations(
            question,
            nodes,
            model,
            filename=f"{config['output']['file_prefix']}_nodes_data.json",
            max_iterations=config['processing']['max_iterations'],
        )

    logging.info("Data generation process completed")


if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | ||
import os | ||
import random | ||
from typing import List, Optional | ||
|
||
from dotenv import load_dotenv | ||
from tqdm import tqdm | ||
|
||
from camel.agents import ChatAgent | ||
from camel.models import ModelFactory | ||
from camel.types import ModelPlatformType | ||
|
||
load_dotenv() | ||
|
||
|
||
class LM:
    r"""Thin wrapper around a CAMEL :class:`ChatAgent` that samples
    step-by-step math-reasoning rollouts from an OpenAI-compatible model.

    Args:
        model_type (str): Backend selector; only ``"camel"`` is supported.
        model_name (str): Name of the underlying model (informational).
        num_rollouts (int, optional): Default number of completions that
            :meth:`generate` samples per call. (default: :obj:`5`)
        **kwargs: Optional generation settings. Recognized keys are
            ``max_tokens`` (int) and ``temperature_range`` ([float, float]).
            They may be passed flat, or bundled in a ``model_args`` dict
            (the calling convention used by ``main``); the nested form
            takes precedence.
    """

    def __init__(self, model_type, model_name, num_rollouts=5, **kwargs):
        self.model_type = model_type
        self.model_name = model_name
        self.num_rollouts = num_rollouts

        # Fix: callers (see main) pass settings nested in `model_args`,
        # which the old code ignored — max_tokens/temperature_range
        # silently fell back to their defaults. Honor both forms.
        model_args = kwargs.get('model_args') or {}
        self.max_tokens = model_args.get(
            'max_tokens', kwargs.get('max_tokens', 4096)
        )
        self.temperature_range = model_args.get(
            'temperature_range', kwargs.get('temperature_range', [0.7, 1.0])
        )

        if self.model_type != "camel":
            raise ValueError("Only camel model type is supported")

        # Initialize camel model.
        # NOTE(review): the env-var names spell "COMPATIBILIY" (sic);
        # kept as-is because the shipped .env uses the same spelling —
        # fix both together if renaming.
        self.model = ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
            model_type=os.environ.get("OPENAI_COMPATIBILIY_ModelType"),
            api_key=os.environ.get("OPENAI_COMPATIBILIY_API_KEY"),
            url=os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL"),
            model_config_dict={
                "temperature": random.uniform(*self.temperature_range),
                "max_tokens": self.max_tokens,
            },
        )

        # Initialize chat agent
        self.agent = ChatAgent(
            system_message='''You are a mathematical reasoning expert who
always solves problems step by step.
For each step:
1. Write down what you're calculating
2. Show the calculation
3. Explain the result
Always show your work, even for simple calculations.
End your solution with the final numerical answer.''',
            model=self.model,
            # NOTE(review): window of 10 limits agent memory — confirm
            # this is intentional for long multi-step solutions.
            message_window_size=10,
        )

    def generate(self, question, partial_answer, num_rollouts=None):
        r"""Sample step-by-step solutions for a question.

        Args:
            question (str): The math problem statement.
            partial_answer (str): Solution steps produced so far; pass an
                empty string to solve from scratch.
            num_rollouts (int, optional): Number of completions to sample.
                Defaults to ``self.num_rollouts`` when :obj:`None`.

        Returns:
            list: Generated solution strings, one per rollout.
        """
        results = []
        if num_rollouts is None:
            num_rollouts = self.num_rollouts

        for _ in tqdm(range(num_rollouts)):
            # Re-draw the sampling temperature for every rollout to get
            # diverse solution paths.
            self.model.model_config_dict["temperature"] = random.uniform(
                *self.temperature_range
            )

            # Construct the prompt: continuation vs. fresh solve.
            if partial_answer:
                prompt = f"""Problem: {question}
Current solution steps:
{partial_answer}
Continue the solution, showing all steps and calculations.
Make sure to explain each step:"""
            else:
                prompt = f"""Problem: {question}
Please solve this step by step, showing all calculations and
explaining each step.
Remember to:
1. Break down the problem
2. Show all calculations
3. Explain each step
4. End with the final numerical answer."""

            # Get response from agent
            response = self.agent.step(prompt)
            results.append(response.msgs[0].content)

        return results

    def generate_rollouts(
        self, prompt: str, num_copies: Optional[int] = None
    ) -> List[str]:
        r"""Generate multiple independent rollouts for a given prompt.

        Args:
            prompt (str): The input prompt to generate responses for.
            num_copies (Optional[int]): Number of copies to generate.
                Defaults to 1 when :obj:`None`.

        Returns:
            List[str]: List of generated responses.
        """
        if num_copies is None:
            num_copies = 1

        rollouts = []
        for _ in range(num_copies):
            # Fix: request exactly one completion per copy. The old code
            # sampled `self.num_rollouts` (default 5) completions here and
            # kept only the first, wasting API calls.
            response = self.generate(prompt, "", num_rollouts=1)
            if response:
                rollouts.append(response[0])

        return rollouts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please update all the comments format by using r""".