diff --git a/rl.py b/rl.py
index 5bfa69d..e9e9b2c 100644
--- a/rl.py
+++ b/rl.py
@@ -1,110 +1,270 @@
-import os
 import torch
-import random
-from torch.optim import Adam
-from transformers import GPT2Config
-from modelscope.msdatasets import MsDataset
-from utils import TunesFormer, Patchilizer, download, DEVICE
-from generate import infer_abc
-from config import *
+import numpy as np
+import torch.nn as nn
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
+
+
+class TextGenerationEnvironment:
+    def __init__(self, model_name_or_path, reward_fn=None, max_length=20):
+        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
+        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
+        self.reward_fn = reward_fn
+        self.max_length = max_length
+        self.current_text = ""
+        self.current_length = 0
+
+    def generate_text(self, input_text, max_length=None):
+        if max_length is None:
+            max_length = self.max_length
+
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
+        output = self.model.generate(input_ids=input_ids, max_length=max_length)
+        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        return generated_text
+
+    def get_tokenizer(self):
+        return self.tokenizer
+
+    def reset(self):
+        self.current_text = ""
+        self.current_length = 0
+        return ""
+
+    def step(self, action):
+        # The action is a GPT-2 token id; append the decoded token to the text.
+        action_token = self.tokenizer.decode([action])
+        self.current_text += action_token
+
+        obs = self.current_text[-self.max_length :]
+        obs = obs if len(obs) > 0 else " "  # Ensure the observation is never empty
+
+        reward = self.reward_fn(self.current_text)
+        self.current_length += 1
+        done = self.current_length >= self.max_length
+
+        return obs, reward, done, {}
+
+
+class ModifiedGPT(nn.Module):
+    def __init__(self, model_name_or_path, num_actions=512):
+        super(ModifiedGPT, self).__init__()
+        self.config = GPT2Config.from_pretrained(model_name_or_path)
+        self.num_actions = num_actions
+        self.gpt = GPT2LMHeadModel.from_pretrained(
+            model_name_or_path, config=self.config
+        )
+        # Policy head: projects the final hidden state onto the action space.
+        self.action_layer = nn.Linear(self.config.n_embd, num_actions)
+        # Value head: estimates the state value used by PPO.
+        self.value_layer = nn.Linear(self.config.n_embd, 1)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None):
+        # Run the GPT-2 transformer body; its first output is the hidden states
+        # (the LM head would return vocabulary logits instead).
+        gpt_outputs = self.gpt.transformer(
+            input_ids=input_ids, attention_mask=attention_mask
+        )
+        hidden_states = gpt_outputs[0]
+        # Use the hidden state at the last position as the state representation.
+        last_hidden = hidden_states[:, -1, :]
+        action_logits = self.action_layer(last_hidden)
+        action_probs = self.softmax(action_logits)
+        value = self.value_layer(last_hidden).squeeze(-1)
+
+        if labels is not None:
+            # Optional supervised loss on the action logits.
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(action_logits.view(-1, self.num_actions), labels.view(-1))
+            return loss, action_probs, value
+
+        return action_probs, value
+
+
+class RewardFunction:
+    def __init__(self, reference_text, model_name_or_path):
+        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
+        self.reference_tokens = self.tokenizer.encode(
+            reference_text, return_tensors="pt"
+        )
+        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
+
+    def __call__(self, generated_text):
+        # Reward is the negative perplexity of the generated text under GPT-2.
+        generated_tokens = self.tokenizer.encode(generated_text, return_tensors="pt")
+        if generated_tokens.shape[-1] < 2:
+            return 0.0  # Perplexity is undefined for fewer than two tokens
+
+        with torch.no_grad():
+            outputs = self.model(input_ids=generated_tokens, labels=generated_tokens)
+            loss = outputs[0]
+
+        perplexity = torch.exp(loss)
+        reward = -perplexity.item()
+        return reward
+
+
+class RolloutStorage:
+    def __init__(self, max_length):
+        self.max_length = max_length
+        self.clear()
+
+    def store(self, obs, action, action_probs, value, reward, done):
+        self.observations.append(obs)
+        self.actions.append(action)
+        self.action_probs.append(action_probs)
+        self.values.append(value)
+        self.rewards.append(reward)
+        self.dones.append(done)
+
+    def store_last_observation(self, value):
+        # Bootstrap value for the final (possibly non-terminal) observation.
+        self.last_value = value
+
+    def clear(self):
+        self.observations = []
+        self.actions = []
+        self.action_probs = []
+        self.values = []
+        self.rewards = []
+        self.dones = []
+        self.returns = []
+        self.advantages = []
+        self.last_value = 0.0
+
+    def compute_returns(self, gamma):
+        # Discounted returns, bootstrapped from the value of the last observation.
+        returns = []
+        R = self.last_value
+        for r, done in zip(reversed(self.rewards), reversed(self.dones)):
+            if done:
+                R = 0
+            R = r + gamma * R
+            returns.insert(0, R)
+
+        self.returns = returns
+        return returns
+
+    def compute_advantages(self, returns):
+        # Advantage = empirical return minus the value estimate from the rollout.
+        advantages = [return_ - value for value, return_ in zip(self.values, returns)]
+        self.advantages = advantages
+        return advantages
+
+    def batch_by_indices(self, indices):
+        obs_batch = [self.observations[i] for i in indices]
+        action_batch = [self.actions[i] for i in indices]
+        action_prob_batch = [self.action_probs[i] for i in indices]
+        advantage_batch = [self.advantages[i] for i in indices]
+        return_batch = [self.returns[i] for i in indices]
+        return obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch
+
+    def __len__(self):
+        return len(self.actions)
 
 
 class PPOTrainer:
     def __init__(
         self,
-        model: TunesFormer,
-        patchilizer: Patchilizer,
-        lr=1e-5,
+        env,
+        model,
+        reward_fn,
+        lr=1e-4,
+        betas=(0.9, 0.999),
+        eps=1e-5,
+        gamma=0.99,
+        clip_param=0.2,
+        value_loss_coef=0.5,
+        entropy_coef=0.01,
+        num_epochs=10,
+        batch_size=64,
     ):
+        self.env = env
         self.model = model
-        self.patchilizer = patchilizer
-        self.optimizer = Adam(self.model.parameters(), lr=lr)
-
-    def _rewards(self, generated_abc):
-        # TODO: Placeholder - Reward computation logic
-        rewards = [1.0] * len(generated_abc)
-        return torch.tensor(rewards)
-
-    def _str2tensor(self, input_str: str):
-        # Convert the string into a tensor
-        tensor = torch.tensor([float(char) for char in input_str])
-        return tensor
-
-    def train(self, prompts: list, epochs=500):
-        for epoch in range(epochs):
-            for i, prompt in enumerate(prompts):
-                # Generate outputs
-                with torch.no_grad():
-                    generated_abc = infer_abc(prompt, self.patchilizer, self.model)
-
-                # Compute rewards
-                rewards = self._rewards(generated_abc)
-
-                # Compute policy loss
-                logits = self.model(prompt)
-                log_probs = torch.log_softmax(logits, dim=-1)
-                target_ids = self._str2tensor(generated_abc)[:, 1:].reshape(-1)
-                log_probs = log_probs[:, :-1, :].reshape(-1, log_probs.size(-1))
-                log_probs = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)
-                policy_loss = -(log_probs * rewards).mean()
-
-                # Optimize model
-                self.optimizer.zero_grad()
-                policy_loss.backward()
-                self.optimizer.step()
-
-                print(
-                    f"Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{len(prompts)}, Loss: {policy_loss.item()}"
-                )
-
-
-def init_model():
-    patch_config = GPT2Config(
-        num_hidden_layers=PATCH_NUM_LAYERS,
-        max_length=PATCH_LENGTH,
-        max_position_embeddings=PATCH_LENGTH,
-        vocab_size=1,
-    )
+        self.reward_fn = reward_fn
+        self.gamma = gamma
+        self.clip_param = clip_param
+        self.value_loss_coef = value_loss_coef
+        self.entropy_coef = entropy_coef
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
 
-    char_config = GPT2Config(
-        num_hidden_layers=CHAR_NUM_LAYERS,
-        max_length=PATCH_SIZE,
-        max_position_embeddings=PATCH_SIZE,
-        vocab_size=128,
-    )
+        self.optimizer = torch.optim.Adam(
+            model.parameters(), lr=lr, betas=betas, eps=eps
+        )
 
-    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
-    if torch.cuda.device_count() > 1:
-        model = torch.nn.DataParallel(model)
+    def _encode_obs(self, obs):
+        token_ids = self.env.tokenizer.encode(obs)
+        if len(token_ids) == 0:
+            # Fall back to the EOS token so an empty observation still has an input.
+            token_ids = [self.env.tokenizer.eos_token_id]
+        return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
 
-    # Move model to GPU if available
-    if not os.path.exists(WEIGHT_PATH):
-        download()
+    def train(self, num_steps):
+        storage = RolloutStorage(self.env.max_length)
+        obs = self.env.reset()
+        done = True  # Forces a reset on the first iteration
 
-    checkpoint = torch.load(WEIGHT_PATH)
+        for _ in range(num_steps):
+            if done:
+                obs = self.env.reset()
 
-    if torch.cuda.device_count() > 1:
-        model.module.load_state_dict(checkpoint["model"])
-    else:
-        model.load_state_dict(checkpoint["model"])
+            # Collect one episode of experience.
+            for _ in range(self.env.max_length):
+                obs_tensor = self._encode_obs(obs)
+                with torch.no_grad():
+                    action_probs_tensor, value_tensor = self.model(obs_tensor)
+                action_probs = action_probs_tensor.squeeze(0).numpy()
+                action_probs = action_probs.astype(np.float64)
+                action_probs /= action_probs.sum()  # Renormalize for np.random.choice
+                action = np.random.choice(len(action_probs), p=action_probs)
 
-    return model.to(DEVICE), Patchilizer()
+                next_obs, reward, done, _ = self.env.step(action)
+                storage.store(
+                    obs, action, action_probs, value_tensor.item(), reward, done
+                )
+                obs = next_obs
+
+                if done:
+                    break
+
+            # Bootstrap the return from the value of the final observation.
+            if not done:
+                with torch.no_grad():
+                    _, value_tensor = self.model(self._encode_obs(obs))
+                storage.store_last_observation(value_tensor.item())
+            else:
+                storage.store_last_observation(0.0)
+
+            returns = storage.compute_returns(self.gamma)
+            storage.compute_advantages(returns)
+
+            for _ in range(self.num_epochs):
+                indices = np.arange(len(storage))
+                np.random.shuffle(indices)
+
+                for batch_start in range(0, len(storage), self.batch_size):
+                    batch_indices = indices[batch_start : batch_start + self.batch_size]
+                    (
+                        obs_batch,
+                        action_batch,
+                        action_prob_batch,
+                        advantage_batch,
+                        return_batch,
+                    ) = storage.batch_by_indices(batch_indices)
+
+                    self.update(
+                        obs_batch,
+                        action_batch,
+                        action_prob_batch,
+                        advantage_batch,
+                        return_batch,
+                    )
+
+            storage.clear()
+
+    def update(
+        self, obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch
+    ):
+        policy_losses, value_losses, entropies = [], [], []
+
+        for obs, action, old_probs, advantage, return_ in zip(
+            obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch
+        ):
+            new_probs, value = self.model(self._encode_obs(obs))
+            new_probs = new_probs.squeeze(0)
+
+            # Probability ratio between the current policy and the rollout policy.
+            ratio = new_probs[action] / (float(old_probs[action]) + 1e-8)
+
+            # PPO clipped surrogate objective.
+            clipped_ratio = torch.clamp(
+                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
+            )
+            surr1 = ratio * advantage
+            surr2 = clipped_ratio * advantage
+            policy_losses.append(-torch.min(surr1, surr2))
+
+            # Value regression towards the empirical return.
+            value_losses.append((value.squeeze() - return_) ** 2)
+
+            # Entropy bonus to keep the policy from collapsing early.
+            entropies.append(-(new_probs * torch.log(new_probs + 1e-8)).sum())
+
+        loss = (
+            torch.stack(policy_losses).mean()
+            + self.value_loss_coef * torch.stack(value_losses).mean()
+            - self.entropy_coef * torch.stack(entropies).mean()
+        )
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
 
 
 if __name__ == "__main__":
-    # Initialize TunesFormer model and tokenizer
-    model, patchilizer = init_model()
-    # Initialize PPO trainer for TunesFormer
-    ppo_trainer = PPOTrainer(model, patchilizer)
-    # load prompts from the dataset
-    trainset = MsDataset.load(f"monetjoe/{DATASET}", split="train")
-    evalset = MsDataset.load(f"monetjoe/{DATASET}", split="test")
-    prompts = set("A:Q1\n", "A:Q2\n", "A:Q3\n", "A:Q4\n", "")
-    for item in list(trainset) + list(evalset):
-        prompts.add("A:" + item["label"] + "\n" + item["prompt"] + "\n")
-        prompts.add(item["prompt"] + "\n")
-
-    prompts = list(prompts)
-    random.shuffle(prompts)
-    # Train the model
-    ppo_trainer.train(prompts)
+    reward_fn = RewardFunction(
+        reference_text="The quick brown fox jumps over the lazy dog",
+        model_name_or_path="gpt2",
+    )
+    env = TextGenerationEnvironment(
+        model_name_or_path="gpt2", reward_fn=reward_fn, max_length=20
+    )
+    model = ModifiedGPT(model_name_or_path="gpt2", num_actions=512)
+    trainer = PPOTrainer(
+        env,
+        model,
+        reward_fn,
+        lr=1e-4,
+        betas=(0.9, 0.999),
+        eps=1e-5,
+        gamma=0.99,
+        clip_param=0.2,
+        value_loss_coef=0.5,
+        entropy_coef=0.01,
+    )
+
+    num_steps = 10000
+    trainer.train(num_steps)
+
+    # Save the trained model
+    torch.save(model.state_dict(), "./output/modified_gpt_model.pth")