ppo_flan_sentiments.py
import os
import pathlib
from typing import Dict, List
import torch
import trlx
import yaml
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
from trlx.data.configs import TRLConfig
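
# Load the default trlx config for this example from the YAML file next to the script.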
config_path = pathlib.Path(__file__).parent.joinpath(
    "./configs/ppo_flan_sentiments.yml"
)
with config_path.open() as f:
    default_config = yaml.safe_load(f)
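
# Zero-shot reward model: Flan-T5 is asked "Is this movie review positive?" and the
# normalized probability of answering "Yes" (vs. "No") is used as the scalar reward.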
class ZeroShotRewardModel:
    def __init__(self) -> None:
        if torch.cuda.is_available():
            self.device = int(os.environ.get("LOCAL_RANK", 0))
        else:
            self.device = "cpu"  # no GPU available; keep the reward model on CPU
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-large"
        ).to(self.device)
        self.yes_token_id = 2163  # this is for Flan-T5, change it accordingly
        self.no_token_id = 465  # this is for Flan-T5, change it accordingly

    def reward_fn(self, samples: List[str], **kwargs) -> List[float]:
        scores = []
        for sample in samples:
            input_text = (
                f"Review: {sample}\n\n Is this movie review positive? Response:"
            )
            x = self.tokenizer([input_text], return_tensors="pt").input_ids.to(
                self.device
            )
            outputs = self.model.generate(
                x, return_dict_in_generate=True, output_scores=True, max_new_tokens=1
            )
            # outputs.scores[0] holds the logits over the vocabulary for the single generated token
            v_yes_exp = (
                torch.exp(outputs.scores[0][:, self.yes_token_id]).cpu().numpy()[0]
            )
            v_no_exp = (
                torch.exp(outputs.scores[0][:, self.no_token_id]).cpu().numpy()[0]
            )
            scores.append(
                (v_yes_exp / (v_yes_exp + v_no_exp) - 0.5) * 10
            )  # we do some rescaling to improve PPO. This is Eq. (3) in the paper
        return scores

    def metric_fn(self, samples: List[str], **kwargs) -> Dict[str, List[float]]:
        """Similar to reward_fn, but without rescaling, to make it interpretable in the logs."""
        scores = []
        for sample in samples:
            input_text = (
                f"Review: {sample}\n\n Is this movie review positive? Response:"
            )
            x = self.tokenizer([input_text], return_tensors="pt").input_ids.to(
                self.device
            )
            outputs = self.model.generate(
                x, return_dict_in_generate=True, output_scores=True, max_new_tokens=1
            )
            v_yes_exp = (
                torch.exp(outputs.scores[0][:, self.yes_token_id]).cpu().numpy()[0]
            )
            v_no_exp = (
                torch.exp(outputs.scores[0][:, self.no_token_id]).cpu().numpy()[0]
            )
            scores.append(v_yes_exp / (v_yes_exp + v_no_exp))
        return {"prob_positive": scores}
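
# PPO training entry point: prompts are the first four words of IMDB reviews;
# the last 64 reviews are held out as evaluation prompts.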
def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    # Load the reward model
    reward_model = ZeroShotRewardModel()

    # Take the first few words of each movie review as prompts
    imdb = load_dataset("imdb", split="train+test")

    trlx.train(
        reward_fn=reward_model.reward_fn,
        prompts=[" ".join(review.split()[:4]) for review in imdb["text"][:-64]],
        metric_fn=reward_model.metric_fn,
        eval_prompts=[" ".join(review.split()[:4]) for review in imdb["text"][-64:]],
        config=config,
    )

if __name__ == "__main__":
    main()
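
# Usage sketch (assumption: the usual launch pattern for trlx examples; adjust to your setup):
#   python ppo_flan_sentiments.py
#   accelerate launch ppo_flan_sentiments.py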