Commit

replace with gpt2 ppo
monetjoe committed Jun 1, 2024
1 parent 4f1ce50 commit f95c1f1
Showing 1 changed file with 253 additions and 93 deletions.
346 changes: 253 additions & 93 deletions rl.py
@@ -1,110 +1,270 @@
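# rl.py: PPO fine-tuning of GPT-2 for text generation. A small environment
# wraps token-by-token generation, a frozen GPT-2 supplies a perplexity-based
# reward, and a GPT-2 policy/value model is optimised with a clipped PPO loss.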
import os

import numpy as np
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer


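# Token-level text-generation environment: each step() appends one decoded
# GPT-2 token to the running text and returns a Gym-style
# (observation, reward, done, info) tuple.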
class TextGenerationEnvironment:
    def __init__(self, model_name_or_path, reward_fn=None, max_length=20):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
        # The reward callable is injected so step() can score the running text
        self.reward_fn = reward_fn
        self.max_length = max_length
        self.current_text = ""
        self.current_length = 0

def generate_text(self, input_text, max_length=None):
if max_length is None:
max_length = self.max_length

input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
output = self.model.generate(input_ids=input_ids, max_length=max_length)
generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
return generated_text

def get_tokenizer(self):
return self.tokenizer

    def reset(self):
        self.current_text = ""
        self.current_length = 0
        # Return a single space so the first observation always encodes to at least one token
        return " "

def step(self, action):
action_token = self.tokenizer.decode([action])
self.current_text += action_token

obs = self.current_text[-self.max_length :]
obs = obs if len(obs) > 0 else " " # Ensure the observation is never empty

        reward = self.reward_fn(self.current_text) if self.reward_fn else 0.0
self.current_length += 1
done = self.current_length >= self.max_length

return obs, reward, done, {}


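# GPT-2 backbone with two small heads: a policy head over a restricted action
# space (the first `num_actions` token ids) and a scalar value head, matching
# the (action_probs, value) pair the PPO trainer unpacks.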
class ModifiedGPT(nn.Module):
    def __init__(self, model_name_or_path, num_actions=512):
        super(ModifiedGPT, self).__init__()
        self.config = GPT2Config.from_pretrained(model_name_or_path)
        self.num_actions = num_actions
        self.gpt = GPT2LMHeadModel.from_pretrained(model_name_or_path, config=self.config)
        # Policy head over the restricted action space plus a scalar value head
        # (the value head is assumed here because the trainer expects (probs, value)).
        self.action_layer = nn.Linear(self.config.n_embd, num_actions)
        self.value_layer = nn.Linear(self.config.n_embd, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        # Run the transformer body directly to obtain hidden states
        # (GPT2LMHeadModel's first output is LM logits, not hidden states).
        hidden_states = self.gpt.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        last_hidden = hidden_states[:, -1, :]
        action_logits = self.action_layer(last_hidden)
        action_probs = self.softmax(action_logits)
        value = self.value_layer(last_hidden).squeeze(-1)

        outputs = (action_probs, value)

        if labels is not None:
            # Optional supervised loss over the restricted action vocabulary
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                action_logits.view(-1, self.num_actions), labels.view(-1)
            )
            outputs = (loss,) + outputs

        return outputs


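# Reward model: scores generated text by the negative perplexity of a frozen
# GPT-2, so more fluent (lower-perplexity) text receives a higher reward.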
class RewardFunction:
    def __init__(self, reference_text, model_name_or_path):
        # Load the tokenizer once instead of re-instantiating it on every call
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        self.reference_tokens = self.tokenizer.encode(
            reference_text, return_tensors="pt"
        )
        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)

    def __call__(self, generated_text):
        generated_tokens = self.tokenizer.encode(generated_text, return_tensors="pt")

with torch.no_grad():
outputs = self.model(input_ids=generated_tokens, labels=generated_tokens)
loss = outputs[0]

perplexity = torch.exp(loss)
reward = -perplexity.item()
return reward


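# On-policy rollout buffer: stores one trajectory's observations, actions,
# behaviour-policy probabilities, rewards, value estimates and done flags,
# plus the returns/advantages computed from them.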
class RolloutStorage:
    def __init__(self, max_length):
        self.observations = []
        self.actions = []
        self.action_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        self.last_value = torch.tensor(0.0)
        self.returns = []
        self.advantages = []
        self.max_length = max_length

    def store(self, obs, action, action_probs, reward, done, value=0.0):
        self.observations.append(obs)
        self.actions.append(action)
        self.action_probs.append(action_probs)
        self.rewards.append(reward)
        self.values.append(value)
        self.dones.append(done)

    def store_last_observation(self, value):
        # Bootstrap value of the final observation, used when the rollout is truncated
        self.last_value = value

    def clear(self):
        self.observations = []
        self.actions = []
        self.action_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        self.last_value = torch.tensor(0.0)
        self.returns = []
        self.advantages = []

    def compute_returns(self, gamma):
        returns = []
        # Bootstrap from the critic's estimate if the last transition was not terminal
        R = 0.0 if (not self.dones or self.dones[-1]) else float(self.last_value)
        for r, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                R = 0.0
            R = r + gamma * R
            returns.insert(0, R)

        self.returns = returns
        return returns

    def compute_advantages(self, returns):
        # Advantage = empirical return minus the critic's value estimate for that step
        advantages = []
        for value, return_ in zip(self.values, returns):
            advantages.append(return_ - float(value))

        self.advantages = advantages
        return advantages

def batch_by_indices(self, indices):
obs_batch = [self.observations[i] for i in indices]
action_batch = [self.actions[i] for i in indices]
action_prob_batch = [self.action_probs[i] for i in indices]
advantage_batch = [self.advantages[i] for i in indices]
return_batch = [self.returns[i] for i in indices]
return obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch

def __len__(self):
return len(self.actions)


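# Minimal PPO trainer: collects rollouts from the text environment, computes
# discounted returns and advantages, then runs several epochs of clipped
# minibatch policy/value updates per collected trajectory.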
class PPOTrainer:
def __init__(
self,
env,
model,
reward_fn,
lr=1e-4,
betas=(0.9, 0.999),
eps=1e-5,
gamma=0.99,
clip_param=0.2,
value_loss_coef=0.5,
entropy_coef=0.01,
num_epochs=10,
batch_size=64,
):
self.env = env
self.model = model
self.reward_fn = reward_fn
self.gamma = gamma
self.clip_param = clip_param
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
        self.num_epochs = num_epochs
self.batch_size = batch_size

self.optimizer = torch.optim.Adam(
model.parameters(), lr=lr, betas=betas, eps=eps
)

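    # Collect one rollout of up to env.max_length tokens per step, then run
    # num_epochs of minibatch PPO updates on it before clearing the buffer.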
def train(self, num_steps):
storage = RolloutStorage(self.env.max_length)
obs = self.env.reset()
        done = True  # force an environment reset at the start of the first rollout

for _ in range(num_steps):
if done:
obs = self.env.reset()

for _ in range(self.env.max_length):
obs_tensor = torch.tensor(
self.env.tokenizer.encode(obs), dtype=torch.long
).unsqueeze(0)
if obs_tensor.shape[-1] == 0:
continue

action_probs_tensor, value_tensor = self.model(obs_tensor)
                action_probs = action_probs_tensor.squeeze(0).detach().numpy().astype(np.float64)
                # Renormalise so the probabilities pass np.random.choice's sum-to-one check
                action_probs = action_probs / action_probs.sum()
                action = np.random.choice(len(action_probs), p=action_probs)

next_obs, reward, done, _ = self.env.step(action)
                storage.store(
                    obs, action, action_probs, reward, done, value_tensor.item()
                )
obs = next_obs

if done:
break

if not done:
obs_tensor = torch.tensor(
self.env.tokenizer.encode(obs), dtype=torch.long
).unsqueeze(0)
_, value_tensor = self.model(obs_tensor)
storage.store_last_observation(value_tensor)
else:
storage.store_last_observation(torch.tensor(0.0))

            # Compute and cache returns/advantages on the storage for minibatching
            storage.compute_returns(self.gamma)
            storage.compute_advantages(storage.returns)

for _ in range(self.num_epochs):
indices = np.arange(len(storage))
np.random.shuffle(indices)

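                # Sample shuffled minibatches from the stored rollout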
for batch_start in range(0, len(storage), self.batch_size):
batch_indices = indices[batch_start : batch_start + self.batch_size]
(
obs_batch,
action_batch,
action_prob_batch,
advantage_batch,
return_batch,
) = storage.batch_by_indices(batch_indices)

self.update(
obs_batch,
action_batch,
action_prob_batch,
advantage_batch,
return_batch,
)

storage.clear()
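
    def update(
        self, obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch
    ):
        # Sketch of the clipped-surrogate PPO update that train() expects; the
        # original commit does not define update(), so the exact objective here
        # is an assumption (standard PPO loss with value and entropy terms).
        policy_losses, value_losses, entropies = [], [], []
        for obs, action, old_probs, advantage, ret in zip(
            obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch
        ):
            obs_tensor = torch.tensor(
                self.env.tokenizer.encode(obs), dtype=torch.long
            ).unsqueeze(0)
            action_probs, value = self.model(obs_tensor)
            dist = torch.distributions.Categorical(probs=action_probs.squeeze(0))
            act = int(action)
            new_log_prob = dist.log_prob(torch.tensor(act))
            old_log_prob = torch.log(torch.tensor(float(old_probs[act])) + 1e-10)
            ratio = torch.exp(new_log_prob - old_log_prob)
            advantage = torch.tensor(float(advantage))
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * advantage
            policy_losses.append(-torch.min(surr1, surr2))
            value_losses.append((value.squeeze() - float(ret)) ** 2)
            entropies.append(dist.entropy())

        loss = (
            torch.stack(policy_losses).mean()
            + self.value_loss_coef * torch.stack(value_losses).mean()
            - self.entropy_coef * torch.stack(entropies).mean()
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()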


if __name__ == "__main__":
    # Build the reward function first so it can also be wired into the environment
    reward_fn = RewardFunction(
        reference_text="The quick brown fox jumps over the lazy dog",
        model_name_or_path="gpt2",
    )
    env = TextGenerationEnvironment(
        model_name_or_path="gpt2", reward_fn=reward_fn, max_length=20
    )
    model = ModifiedGPT(model_name_or_path="gpt2", num_actions=512)
trainer = PPOTrainer(
env,
model,
reward_fn,
lr=1e-4,
betas=(0.9, 0.999),
eps=1e-5,
gamma=0.99,
clip_param=0.2,
value_loss_coef=0.5,
entropy_coef=0.01,
)

num_steps = 10000
trainer.train(num_steps)

    # Save the trained model, creating the output directory if needed
    os.makedirs("./output", exist_ok=True)
    torch.save(model.state_dict(), "./output/modified_gpt_model.pth")
