
Merge branch 'main' into ig/pytest-diff-tests-fixes
imangohari1 committed Aug 14, 2024
2 parents 3f79d4b + fe62fcb commit f25a606
Showing 6 changed files with 119 additions and 57 deletions.
27 changes: 27 additions & 0 deletions examples/text-generation/README.md
@@ -499,6 +499,33 @@ python run_generation.py \
```


### Loading 4-Bit Checkpoints from Hugging Face

You can load pre-quantized 4-bit models with the argument `--load_quantized_model`.
Currently, only uint4 checkpoints and single-device configurations are supported.
More information on enabling 4-bit inference in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_INT4.html.

Below is an example of loading a model with 4-bit checkpoints from Hugging Face.
Note that the model name is denoted as `<model_path_in_hugging_face>`.
Additionally, the following environment variables are used for performance optimizations and are planned to be removed in a future release:
`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1`
```bash
SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
python run_lm_eval.py \
-o acc_load_uint4_model.txt \
--model_name_or_path <model_path_in_hugging_face> \
--use_hpu_graphs \
--use_kv_cache \
--trim_logits \
--batch_size 1 \
--bf16 \
--attn_softmax_bf16 \
--bucket_size=128 \
--bucket_internal \
--load_quantized_model
```

### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths in the prompt stage of inference. Using a causal attention mask in the prompt stage requires all input sequences in a batch to be of the same length, but it can provide memory savings and thus enable higher batch sizes.
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -299,6 +299,11 @@ def setup_parser(parser):
action="store_true",
help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
)
parser.add_argument(
"--load_quantized_model",
action="store_true",
help="Whether to load model from hugging face checkpoint.",
)
parser.add_argument(
"--parallel_strategy",
type=str,
7 changes: 7 additions & 0 deletions examples/text-generation/utils.py
@@ -246,6 +246,10 @@ def setup_model(args, model_dtype, model_kwargs, logger):
torch_dtype=model_dtype,
**model_kwargs,
)
elif args.load_quantized_model:
from neural_compressor.torch.quantization import load

model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
else:
if args.assistant_model is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
@@ -619,6 +623,9 @@ def initialize_model(args, logger):
"trust_remote_code": args.trust_remote_code,
}

if args.load_quantized_model:
model_kwargs["torch_dtype"] = torch.bfloat16

if args.trust_remote_code:
logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail")

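For reference, a minimal standalone sketch of the new `--load_quantized_model` path, based only on the calls shown in this diff; the model path is a placeholder, and `torch_dtype` mirrors what `initialize_model` puts into `model_kwargs` for this branch:

```python
import torch
from neural_compressor.torch.quantization import load  # API used by setup_model above

# Placeholder path; substitute a repository that hosts a uint4 checkpoint.
model = load(
    model_name_or_path="<model_path_in_hugging_face>",
    format="huggingface",
    device="hpu",
    torch_dtype=torch.bfloat16,  # initialize_model forces bf16 for this branch
)
model.eval()
```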
27 changes: 15 additions & 12 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -162,10 +162,10 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
sin: Optional[torch.Tensor] = None,
cos: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@@ -221,19 +221,22 @@ def forward(
key = torch.cat([past_key, key], dim=-2)
value = torch.cat([past_value, value], dim=-2)

if use_cache is True and token_idx is not None:
if use_cache is True:
if reuse_cache:
key = self.k_cache(key, 2, token_idx)
value = self.v_cache(value, 2, token_idx)
present = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if layer_past is None:
past_key = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
past_value = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
layer_past = (past_key, past_value)
key = self.k_cache.update(layer_past[0], key, 2, token_idx, self.inp_seq_len)
value = self.v_cache.update(layer_past[1], value, 2, token_idx, self.inp_seq_len)
present = layer_past
if token_idx is not None:
if layer_past is None:
past_key = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
past_value = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
layer_past = (past_key, past_value)
key = self.k_cache.update(layer_past[0], key, 2, token_idx, self.inp_seq_len)
value = self.v_cache.update(layer_past[1], value, 2, token_idx, self.inp_seq_len)
present = layer_past
else:
present = (key.to(hidden_states.dtype), value)

if cache_idx is not None and q_len == 1:
key = key[:, :, :cache_idx, :]
@@ -288,10 +291,10 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
sin: Optional[torch.Tensor] = None,
cos: Optional[torch.Tensor] = None,
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
"""
Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
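To make the new branching in the GPT-J attention above easier to follow, here is a hypothetical, self-contained sketch in plain PyTorch (not the optimum-habana `KVCache` API, and ignoring the `reuse_cache` path): with `token_idx` the new key/value is written into a pre-allocated cache at that position, and without it the freshly computed tensors are returned directly as the present state.

```python
import torch

def update_static_cache(cache: torch.Tensor, new: torch.Tensor, token_idx: torch.Tensor) -> torch.Tensor:
    # Write `new` (shape [batch, heads, 1, head_dim]) into `cache` at position
    # `token_idx` along the sequence dimension, as in a single-token decode step.
    return cache.index_copy_(2, token_idx, new)

def present_kv(key, value, layer_past, token_idx):
    # token_idx given: update the pre-allocated static cache in place and return it.
    if token_idx is not None:
        key = update_static_cache(layer_past[0], key, token_idx)
        value = update_static_cache(layer_past[1], value, token_idx)
        return key, value, layer_past
    # No token_idx (dynamic shapes): the freshly computed tensors are the present state.
    return key, value, (key, value)

# Example: single-token decode writing into position 3 of a pre-allocated cache.
batch, heads, max_len, head_dim = 1, 2, 8, 4
past = (
    torch.zeros(batch, heads, max_len, head_dim),
    torch.zeros(batch, heads, max_len, head_dim),
)
new_k = torch.randn(batch, heads, 1, head_dim)
new_v = torch.randn(batch, heads, 1, head_dim)
k, v, present = present_kv(new_k, new_v, past, token_idx=torch.tensor([3]))
print(k.shape, v.shape)  # torch.Size([1, 2, 8, 4]) torch.Size([1, 2, 8, 4])
```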
25 changes: 20 additions & 5 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -423,9 +423,22 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

def get_k_proj_weight(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
if hasattr(self.k_proj, "qweight"):
return self.k_proj.qweight
return self.k_proj.weight

def get_k_proj_weight_dtype(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight.
Scales tensor gets the weight dtype."""
if hasattr(self.k_proj, "qweight"):
return self.k_proj.scales.dtype
return self.k_proj.weight.dtype

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.k_proj.weight.device
device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
@@ -436,7 +449,7 @@ def update_sincos_cache(self, seq_len):
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
_, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)

def reorder(self, tensor, beam_idx, dim_a, dim_b):
updated = tensor.index_select(0, beam_idx)
@@ -493,7 +506,7 @@ def pre_attn_forward(
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
key_slices = self.get_k_proj_weight().split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
@@ -565,9 +578,11 @@ def pre_attn_forward(
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
if past_key_value is None:
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_key = torch.zeros(
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
)
# Return list instead of tuple
past_key_value = [past_key, past_value]
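To illustrate why the new `get_k_proj_weight` / `get_k_proj_weight_dtype` helpers are needed, here is a hypothetical sketch (the `FakeQuantLinear` module is a stand-in, not the actual GPTQ implementation): a 4-bit GPTQ linear stores packed integer `qweight` and floating-point `scales` instead of a dense `weight`, so the device and dtype have to be read from different attributes depending on the module type.

```python
import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    """Stand-in for a GPTQ 4-bit linear: packed int weights plus per-group scales."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("qweight", torch.zeros(in_features // 8, out_features, dtype=torch.int32))
        self.register_buffer("scales", torch.ones(1, out_features, dtype=torch.bfloat16))

def get_weight_like(proj: nn.Module) -> torch.Tensor:
    # Mirrors get_k_proj_weight: prefer `qweight` when the layer is quantized.
    return proj.qweight if hasattr(proj, "qweight") else proj.weight

def get_weight_dtype(proj: nn.Module) -> torch.dtype:
    # Mirrors get_k_proj_weight_dtype: the scales tensor carries the compute dtype.
    return proj.scales.dtype if hasattr(proj, "qweight") else proj.weight.dtype

dense = nn.Linear(16, 16)
quant = FakeQuantLinear(16, 16)
print(get_weight_like(dense).dtype, get_weight_dtype(dense))  # torch.float32 torch.float32
print(get_weight_like(quant).dtype, get_weight_dtype(quant))  # torch.int32 torch.bfloat16
```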
(diff for the sixth changed file not shown)
