Showing 6 changed files with 501 additions and 0 deletions.
@@ -0,0 +1,5 @@
# When customizing the model, set these variables to the model you want
# to call, e.g. the deepseek-chat model.

OPENAI_COMPATIBILIY_ModelType=deepseek-chat
OPENAI_COMPATIBILIY_API_BASE_URL=https://api.deepseek.com
OPENAI_COMPATIBILIY_API_KEY=
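For reference, the model_utils.py file later in this commit reads these variables with python-dotenv. A minimal sketch to verify they are picked up (the print is purely illustrative):

import os

from dotenv import load_dotenv

load_dotenv()  # reads the .env file from the current working directory

# Same variable names as used by model_utils.py in this commit
model_type = os.environ.get("OPENAI_COMPATIBILIY_ModelType")
base_url = os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL")
api_key = os.environ.get("OPENAI_COMPATIBILIY_API_KEY")
print(model_type, base_url, "key set:", bool(api_key))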
@@ -0,0 +1,25 @@
input:
  json_file_path: 'example_problems.json'

output:
  file_prefix: 'example'
  log_file_path: 'example_processing.log'

processing:
  initial_rollouts: 20
  num_rollouts: 20
  max_iterations: 100

model:
  model_type: "camel"
  model_name: "deepseek-chat"
  model_args:
    max_tokens: 200
    temperature_range: [0.7, 1.0]

# There are 32 questions to choose from for the initial rollout execution.
# If the questions are too easy (all rollouts correct) or too hard (all wrong),
# OmegaPRM will not process them.
# So you may need a problem set of the right difficulty, and a model that is
# not too strong (such as one of Qwen's open-source small models), for node
# data to be obtained.
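The difficulty filter described in the comment above is implemented in the script later in this commit as a simple bounds check on the initial-rollout correctness flags; a self-contained sketch:

def needs_annotation(correctness_flags, initial_rollouts):
    # Annotate only if some, but not all, initial rollouts were correct;
    # all-correct (too easy) and all-wrong (too hard) problems are skipped.
    return 0 < sum(correctness_flags) < initial_rollouts

assert not needs_annotation([0] * 20, 20)          # too hard: skipped
assert needs_annotation([1, 0, 1] + [0] * 17, 20)  # mixed: processed
assert not needs_annotation([1] * 20, 20)          # too easy: skipped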
@@ -0,0 +1,6 @@
[
  {
    "problem": "How many ways can we put 3 math books and 5 English books on a shelf if all the math books must stay together and all the English books must also stay together? (The math books are all different and so are the English books.)",
    "final_answer": "1440"
  }
]
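As a sanity check on the listed answer: treating the math books as one block and the English books as another gives 2! block orderings, with 3! arrangements within the math block and 5! within the English block, so 2 x 6 x 120 = 1440.

import math

# 2! block orderings x 3! math-book orders x 5! English-book orders
assert math.factorial(2) * math.factorial(3) * math.factorial(5) == 1440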
@@ -0,0 +1,146 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging

import yaml
from model_utils import LM
from module import (
    Node,
    calculate_mc_score,
    perform_rollouts,
    process_annotations,
)


def load_config(config_path):
    """
    Load configuration from a YAML file.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        dict: A dictionary containing the configuration.
    """
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


def load_json_file(file_path):
    """
    Load data from a JSON file.

    Args:
        file_path (str): Path to the JSON file.

    Returns:
        list: A list of dictionaries containing the problem and final answer.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data


def setup_logging(log_file):
    """
    Set up logging configuration to output to file and console.

    Args:
        log_file (str): Path to the log file.
    """
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    console_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(console_handler)

    logging.getLogger("openai").setLevel(logging.ERROR)
    logging.getLogger("httpx").setLevel(logging.WARNING)


def main():
    # Load configuration
    config = load_config('config.yaml')

    # Get parameters from config
    json_file_path = config['input']['json_file_path']
    log_file_path = config['output']['log_file_path']
    file_prefix = config['output']['file_prefix']
    num_rollouts = config['processing']['num_rollouts']
    initial_rollouts = config['processing']['initial_rollouts']
    max_iterations = config['processing']['max_iterations']

    lm_model = LM(
        model_type=config['model']['model_type'],
        model_name=config['model']['model_name'],
        num_rollouts=num_rollouts,
        **config['model']['model_args'],
    )

    # Set up logging
    setup_logging(log_file_path)

    # Start the process and log it
    logging.info("Started processing the JSON file.")

    # Load the JSON data
    data = load_json_file(json_file_path)

    # Process each problem and its final answer
    for i, item in enumerate(data):
        problem = item.get('problem', 'No problem found')
        final_answer = item.get('final_answer', 'No answer found')

        # Log each problem and answer
        logging.info(f"Processed Problem {i + 1}: {problem}")
        logging.info(f"Final Answer: {final_answer}")

        # Initialize the root node and perform rollouts
        nodes = []
        root_node = Node(problem, "", final_answer)
        rollouts, correctness_flags = perform_rollouts(
            root_node, lm_model, initial_rollouts
        )
        mc_score = calculate_mc_score(root_node)
        root_node.mc_score = mc_score

        nodes.append(root_node)

        # Check if further processing is needed
        if 0 < sum(correctness_flags) < initial_rollouts:
            print("Processing annotations ...\n")
            filename = f"{file_prefix}_{i + 1}_nodes_data.json"
            process_annotations(
                problem, nodes, lm_model, filename, max_iterations
            )

    # Log completion
    logging.info("Finished processing the JSON file.")


if __name__ == "__main__":
    main()
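To run the example end to end, place config.yaml, example_problems.json, and the .env file from this commit in the working directory and execute the script; per-problem node data is written to files named with the configured prefix (e.g. example_1_nodes_data.json).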
@@ -0,0 +1,96 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
import random

from dotenv import load_dotenv
from tqdm import tqdm

from camel.agents import ChatAgent
from camel.models import ModelFactory
from camel.types import ModelPlatformType

load_dotenv()


class LM:
    def __init__(self, model_type, model_name, num_rollouts=5, **kwargs):
        self.model_type = model_type
        self.model_name = model_name
        self.num_rollouts = num_rollouts
        self.max_tokens = kwargs.get('max_tokens', 4096)
        self.temperature_range = kwargs.get('temperature_range', [0.7, 1.0])

        if self.model_type != "camel":
            raise ValueError("Only camel model type is supported")

        # Initialize camel model
        self.model = ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
            model_type=os.environ.get("OPENAI_COMPATIBILIY_ModelType"),
            api_key=os.environ.get("OPENAI_COMPATIBILIY_API_KEY"),
            url=os.environ.get("OPENAI_COMPATIBILIY_API_BASE_URL"),
            model_config_dict={
                "temperature": random.uniform(*self.temperature_range),
                "max_tokens": self.max_tokens,
            },
        )

        # Initialize chat agent
        self.agent = ChatAgent(
            system_message='''You are a mathematical reasoning expert who
            always solves problems step by step.
            For each step:
            1. Write down what you're calculating
            2. Show the calculation
            3. Explain the result
            Always show your work, even for simple calculations.
            End your solution with the final numerical answer.''',
            model=self.model,
            message_window_size=10,
        )

    def generate(self, question, partial_answer, num_rollouts=None):
        results = []
        if num_rollouts is None:
            num_rollouts = self.num_rollouts

        for _ in tqdm(range(num_rollouts)):
            # Update temperature for each rollout
            self.model.model_config_dict["temperature"] = random.uniform(
                *self.temperature_range
            )

            # Construct the prompt
            if partial_answer:
                prompt = f"""Problem: {question}
                Current solution steps:
                {partial_answer}
                Continue the solution, showing all steps and calculations.
                Make sure to explain each step:"""
            else:
                prompt = f"""Problem: {question}
                Please solve this step by step, showing all calculations and
                explaining each step.
                Remember to:
                1. Break down the problem
                2. Show all calculations
                3. Explain each step
                4. End with the final numerical answer."""

            # Get response from agent
            response = self.agent.step(prompt)
            results.append(response.msgs[0].content)

        return results
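A minimal usage sketch of the LM class, assuming the .env variables above are set; the constructor arguments mirror the model section of config.yaml:

from model_utils import LM

lm = LM(
    model_type="camel",
    model_name="deepseek-chat",
    num_rollouts=5,
    max_tokens=200,
    temperature_range=[0.7, 1.0],
)
# One completion per rollout; an empty partial_answer requests a full solution
completions = lm.generate("What is 15% of 80?", partial_answer="")
print(completions[0])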