Add support for contrastive search (#943)
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
skavulya and regisss authored Aug 5, 2024
1 parent 2e8a80d commit 4261160
Showing 11 changed files with 793 additions and 87 deletions.
1 change: 1 addition & 0 deletions examples/text-generation/README.md
@@ -120,6 +120,7 @@ Here are a few settings you may be interested in:
- `--limit_hpu_graphs` to skip HPU Graph usage for first token to save memory
- `--use_kv_cache` to use the [key/value cache](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.use_cache) to speed up generation
- `--do_sample` or `--num_beams` to generate new tokens using sampling or beam search (greedy search is the default)
- `--top_k` and `--penalty_alpha` to generate new tokens using contrastive search (greedy search is the default); see the example below
- `--prompt` to benchmark the model on one or several prompts of your choice
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
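For orientation, here is a minimal sketch of what these two flags correspond to when calling `generate()` directly from Python. The `gpt2` checkpoint, prompt, and parameter values are illustrative assumptions, not part of this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Technologies is", return_tensors="pt")

# top_k > 1 together with penalty_alpha > 0 switches generate() from greedy
# decoding to contrastive search (do_sample and num_beams keep their defaults).
outputs = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```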
12 changes: 12 additions & 0 deletions examples/text-generation/run_generation.py
@@ -102,6 +102,18 @@ def setup_parser(parser):
type=int,
help="Number of beams used for beam search generation. 1 means greedy search will be performed.",
)
parser.add_argument(
"--top_k",
default=None,
type=int,
help="Size of candidate set used for re-ranking in contrastive search. top_k > 1 enables contrastive search.",
)
parser.add_argument(
"--penalty_alpha",
default=None,
type=float,
help="Degeneration penalty for contrastive search. penalty_alpha > 0 enables contrastive search.",
)
parser.add_argument(
"--trim_logits",
action="store_true",
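The help strings above encode the enabling rule: contrastive search is selected only when `top_k > 1` and `penalty_alpha > 0` while sampling and beam search stay disabled (both options default to `None`, so existing runs are unaffected). A quick sanity check, assuming a transformers release that exposes `GenerationConfig.get_generation_mode()`:

```python
from transformers import GenerationConfig
from transformers.generation.configuration_utils import GenerationMode

config = GenerationConfig(top_k=4, penalty_alpha=0.5, do_sample=False, num_beams=1)
# With these values the resolved generation mode is contrastive search.
assert config.get_generation_mode() == GenerationMode.CONTRASTIVE_SEARCH
```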
2 changes: 2 additions & 0 deletions examples/text-generation/utils.py
@@ -526,6 +526,8 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.bucket_internal = args.bucket_internal
generation_config.do_sample = args.do_sample
generation_config.num_beams = args.num_beams
generation_config.top_k = args.top_k
generation_config.penalty_alpha = args.penalty_alpha
generation_config.bad_words_ids = bad_words_ids
generation_config.force_words_ids = force_words_ids
generation_config.num_return_sequences = args.num_return_sequences
613 changes: 542 additions & 71 deletions optimum/habana/transformers/generation/utils.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -47,6 +47,7 @@
GaudiGemmaForCausalLM,
GaudiGPT2Attention,
GaudiGPT2Block,
GaudiGPT2DoubleHeadsModel,
GaudiGPT2LMHeadModel,
GaudiGPTBigCodeForCausalLM,
GaudiGPTJAttention,
@@ -255,6 +256,7 @@ def adapt_transformers_to_gaudi():
transformers.generation.GenerationMixin._beam_search = GaudiGenerationMixin._beam_search
transformers.generation.GenerationMixin._group_beam_search = GaudiGenerationMixin._group_beam_search
transformers.generation.GenerationMixin._constrained_beam_search = GaudiGenerationMixin._constrained_beam_search
transformers.generation.GenerationMixin._contrastive_search = GaudiGenerationMixin._contrastive_search
transformers.generation.GenerationMixin._assisted_decoding = GaudiGenerationMixin._assisted_decoding
transformers.generation.GenerationMixin._get_candidate_generator = GaudiGenerationMixin._get_candidate_generator
transformers.generation.GenerationConfig = GaudiGenerationConfig
@@ -317,6 +319,7 @@ def adapt_transformers_to_gaudi():
transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention
transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = gaudi_gpt2_forward
transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = GaudiGPT2LMHeadModel
transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel = GaudiGPT2DoubleHeadsModel
transformers.models.gpt2.modeling_gpt2.GPT2Block = GaudiGPT2Block
models_with_tracing_support.extend((GaudiGPT2Attention, GaudiGPT2LMHeadModel))

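The patching above is what routes generation through the Gaudi contrastive-search implementation and swaps in `GaudiGPT2DoubleHeadsModel`. A hedged sketch of how to observe the effect, assuming an environment with optimum-habana and its Habana dependencies installed:

```python
import transformers
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Replace the stock transformers classes/methods with their Gaudi counterparts.
adapt_transformers_to_gaudi()

# After patching, the GPT-2 double-heads class resolves to the Gaudi variant,
# and GenerationMixin._contrastive_search points at the Gaudi implementation.
print(transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.__name__)
# -> GaudiGPT2DoubleHeadsModel
```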
8 changes: 7 additions & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -70,7 +70,13 @@
gaudi_gemma_attention_forward,
gaudi_gemma_model_forward,
)
from .gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward
from .gpt2 import (
GaudiGPT2Attention,
GaudiGPT2Block,
GaudiGPT2DoubleHeadsModel,
GaudiGPT2LMHeadModel,
gaudi_gpt2_forward,
)
from .gpt_bigcode import (
GaudiGPTBigCodeForCausalLM,
gaudi_gpt_bigcode_attention_forward,
8 changes: 7 additions & 1 deletion optimum/habana/transformers/models/gpt2/__init__.py
@@ -1 +1,7 @@
from .modeling_gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward
from .modeling_gpt2 import (
GaudiGPT2Attention,
GaudiGPT2Block,
GaudiGPT2DoubleHeadsModel,
GaudiGPT2LMHeadModel,
gaudi_gpt2_forward,
)
148 changes: 147 additions & 1 deletion optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -3,7 +3,14 @@
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2LMHeadModel, logger
from transformers.models.gpt2.modeling_gpt2 import (
GPT2MLP,
GPT2Attention,
GPT2DoubleHeadsModel,
GPT2DoubleHeadsModelOutput,
GPT2LMHeadModel,
logger,
)


class GaudiGPT2Attention(GPT2Attention):
@@ -586,3 +593,142 @@ def forward(
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)


class GaudiGPT2DoubleHeadsModel(GPT2DoubleHeadsModel):
"""
Copied from GPT2DoubleHeadsModel: https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/gpt2/modeling_gpt2.py#L1377
The only differences are:
- add a new arg token_idx to support static shapes
"""

def prepare_inputs_for_generation(
self, input_ids, inputs_embeds=None, past_key_values=None, token_idx=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
else:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.index_select(position_ids, 1, token_idx - 1)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}

model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"token_idx": token_idx,
}
)

return model_inputs

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mc_token_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
mc_labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
)

hidden_states = transformer_outputs[0]

# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)

lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

mc_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
lm_loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

if not return_dict:
output = (lm_logits, mc_logits) + transformer_outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output

return GPT2DoubleHeadsModelOutput(
loss=lm_loss,
mc_loss=mc_loss,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
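To illustrate what `token_idx` buys under static shapes: the sequence is padded to a fixed length and `token_idx` tracks the position currently being generated, so `prepare_inputs_for_generation` can pick the last real token with `index_select` instead of slicing off a variable-length prefix. A tiny sketch with made-up values (not taken from the commit):

```python
import torch

# Hypothetical batch of shape (1, 6), padded with zeros; positions 3-5 are padding.
input_ids = torch.tensor([[15496, 11, 995, 0, 0, 0]])
token_idx = torch.tensor([3])  # the next token will be written at index 3

# Same selection as in prepare_inputs_for_generation when past_key_values is set:
last_token = torch.index_select(input_ids, 1, token_idx - 1)
print(last_token)  # tensor([[995]])
```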
17 changes: 17 additions & 0 deletions tests/test_text_generation_example.py
@@ -75,6 +75,9 @@
"distributed_tp": [
("meta-llama/Llama-2-7b-hf", 1345.2369318328463),
],
"contrastive_search": [
("gpt2-xl", 1, False, 51.61471298016438),
],
}
else:
# Gaudi1 CI baselines
@@ -106,6 +109,9 @@
"torch_compile": [],
"torch_compile_distributed": [],
"distributed_tp": [],
"contrastive_search": [
("gpt2-xl", 1, False, 34.48141280163397),
],
}


@@ -122,6 +128,7 @@ def _test_text_generation(
max_input_tokens: int = 0,
max_output_tokens: int = 100,
parallel_strategy: str = None,
contrastive_search: bool = False,
):
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -177,6 +184,9 @@
if not deepspeed:
command.append("--bf16")

if contrastive_search:
command += ["--top_k 4", "--penalty_alpha 0.5"]

if fp8:
if "--trim_logits" not in command:
command += ["--trim_logits"]
@@ -327,6 +337,13 @@ def test_text_generation_distributed_tp(model_name: str, baseline: float, token:
)


@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["contrastive_search"])
def test_text_generation_contrastive_search(
model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str
):
_test_text_generation(model_name, baseline, token, batch_size, reuse_cache, contrastive_search=True)


class TextGenPipeline(TestCase):
def test_text_generation_pipeline_script(self):
path_to_script = (
42 changes: 42 additions & 0 deletions tests/transformers/tests/generation/test_utils.py
@@ -1521,6 +1521,48 @@ def test_contrastive_generate_low_memory(self):

return

def test_contrastive_generate_dynamic_shapes(self):
# Check that choosing dynamic shapes does not change the model output
for model_class in self.all_generative_model_classes:
# won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
if any(
model_name in model_class.__name__.lower()
for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
):
return

config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return

config.use_cache = True
config.is_decoder = True

# test output equality of dynamic vs. static shapes
model = model_class(config).to(torch_device).eval()
model.generation_config.static_shapes = False
dynamic_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
max_length=max_length,
attention_mask=attention_mask,
)

model.generation_config.static_shapes = True
static_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
max_length=max_length,
attention_mask=attention_mask,
)
self.assertListEqual(dynamic_output.tolist(), static_output.tolist())

return

@pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana")
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
def test_assisted_decoding_matches_greedy_search(self):
26 changes: 13 additions & 13 deletions tests/transformers/tests/models/gptj/test_modeling_gptj.py
@@ -630,7 +630,7 @@ def test_contrastive_search_gptj(self):

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.bfloat16
).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

@@ -641,17 +641,17 @@
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, "
"Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's "
"parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating "
"a company that would apply deep learning to problems in healthcare, energy, transportation, and "
"other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 "
"million in cash and stock.[3] The acquisition was seen as a way for Google to enter the "
"fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns "
'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" '
'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."'
"[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google "
"employees"
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London "
"and has offices in New York, San Francisco, Cambridge, London, Paris, Tokyo, Beijing, Seoul, "
"Singapore, Sydney, and Mountain View.[1]\n\nContents\n\nIn 2010, Google's parent company, "
"Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that "
"would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2] "
"The investment was led by Founders Fund, a venture capital firm that invests in early-stage "
"start-ups, and the London-based venture capital firm Atomico.[3]\n\nOn April 23, 2014, Google "
"announced that it had acquired DeepMind for $400 million in cash and stock.[4] The acquisition was "
"seen as a way for Google to gain access to the company's expertise in machine learning and "
"artificial intelligence (AI), which it could apply to a range of products and services at Google.[5] "
'Google CEO Larry Page said that the acquisition would "make Google a leader in this new field and '
"help answer some of the most challenging questions we face as a society—"
],
)
