improved sharded performance and fixed issue with lmhead on rocm (#2008)
* improved sharded performance and fixed issue with lmhead on rocm

* mmap shards + disable sharing of device arrays across devices

* fix device_idx for non-layer vmfbs

* fix time calc for sharded

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
3 people authored Dec 5, 2023
1 parent 6384780 commit dfdd3b1
Showing 3 changed files with 105 additions and 51 deletions.
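Note on the "mmap shards" bullet: the diff below flips every SharkInference construction from mmap=False to mmap=True, so shard vmfbs are memory-mapped rather than read fully into host memory. A minimal sketch of loading one prebuilt shard that way, mirroring the calls visible in this diff; the import path, device string, and shard filename are illustrative assumptions, not part of the commit:

# Sketch only: load one prebuilt shard vmfb with memory mapping, based on the
# SharkInference usage shown in the diff below. Import path, device string,
# and vmfb filename are assumptions for illustration.
from pathlib import Path

from shark.shark_inference import SharkInference

idx = 0                              # shard index (hypothetical)
n_devices = 2                        # number of visible devices (hypothetical)
vmfb_path = Path(f"vicuna_shards/{idx}_full.vmfb")

shard = SharkInference(
    None,                            # no MLIR module needed when loading a prebuilt vmfb
    device="rocm",
    device_idx=idx % n_devices,      # spread shards across devices, as the diff does
    mlir_dialect="tm_tensor",
    mmap=True,                       # memory-map the shard instead of copying it into memory
)
shard.load_module(vmfb_path)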
100 changes: 65 additions & 35 deletions apps/language_models/scripts/vicuna.py
@@ -12,6 +12,8 @@
import time
from dataclasses import dataclass
from os import environ
from dataclasses import dataclass
from os import environ

import torch
import torch_mlir
@@ -510,6 +512,8 @@ def __init__(
n_devices=None,
) -> None:
self.hf_auth_token = hf_auth_token
self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b" : 5120}
self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b" : 40}
super().__init__(
model_name,
hf_model_path,
@@ -711,6 +715,27 @@ def get_device_index(self, layer_string):
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx


def write_dynamic_inputs_lmhead(self, ir, sample_input_length):
if self.precision in ["fp16", "int4"]:
precision_str = "f16"
else:
precision_str = "f32"
lines = ir.splitlines()
new_lines = []
for line in lines:
if f"%cst_0 =" in line:
new_lines.append(line)
new_lines.append("%c1 = arith.constant 1 : index")
new_lines.append(f"%dim = tensor.dim %arg0, %c1 : tensor<1x?x{self.hidden_state_size_dict[self.model_name]}x{precision_str}>")
else:
line = re.sub(f"{sample_input_length}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
new_lines.append(line)

return "\n".join(new_lines)

def compile_lmhead(
self,
lmh,
@@ -775,14 +800,21 @@ def compile_lmhead(
use_tracing=False,
verbose=False,
)

"""
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
"""
module = str(module)
if self.precision in ["int4", "fp16"]:
module = self.write_dynamic_inputs_lmhead(module, 137)
filepath = Path(f"{self.dir_name}/lmhead.mlir")
f_ = open(mlir_path, "w+")
f_.write(module)
f_.close()
# download_public_file(
# "gs://shark_tank/elias/compressed_sv/lmhead.mlir",
# filepath.absolute(),
@@ -795,7 +827,7 @@ def compile_lmhead(
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -883,7 +915,7 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -964,7 +996,7 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -1163,12 +1195,13 @@ def compile_to_vmfb_one_model(
device_idx = idx % self.n_devices
else:
device_idx = None
print(device_idx, self.n_devices)
module = SharkInference(
None,
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
mmap=False,
mmap=True,
)
module.load_module(vmfb_path)
else:
@@ -1180,13 +1213,13 @@
if self.n_devices is not None:
device_idx = idx % self.n_devices
else:
device_idx = 0
device_idx = None
module = SharkInference(
mlirs[idx],
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
mmap=False,
mmap=True,
)
module.save_module(
module_name=f"{self.dir_name}/{idx}_full",
@@ -1238,7 +1271,7 @@ def compile_to_vmfb_one_model4(
if self.n_devices is not None:
device_idx = idx % self.n_devices
else:
device_idx = 0
device_idx = None
module = SharkInference(
None,
device=device,
@@ -1256,7 +1289,7 @@ def compile_to_vmfb_one_model4(
if self.n_devices is not None:
device_idx = idx % self.n_devices
else:
device_idx = 0
device_idx = None
module = SharkInference(
mlirs[idx],
device=device,
@@ -1320,41 +1353,42 @@ def get_sharded_model(self, device="cpu", compressed=False):

placeholder_pkv_segment = tuple(
(
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
)
for _ in range(8)
)
placeholder_pkv_full = tuple(
(
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
)
for _ in range(32)
for _ in range(self.n_layers_dict[self.model_name])
)

placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)

placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, self.hidden_state_size_dict[self.model_name]]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
)

norm = VicunaNorm(vicuna_model.model.norm)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
)
print(device_idx)
# HC device_idx for non-layer vmfbs
device_idx = 0
norm = self.compile_norm(
norm,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
device=self.device,
device_idx=device_idx,
)
@@ -1363,7 +1397,8 @@ def get_sharded_model(self, device="cpu", compressed=False):
device_idx = self.get_device_index(
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
)
print(device_idx)
# HC device_idx for non-layer vmfbs
device_idx = 0
embeddings = self.compile_embedding(
embeddings,
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
@@ -1375,10 +1410,11 @@ def get_sharded_model(self, device="cpu", compressed=False):
device_idx = self.get_device_index(
r"vicuna\.model\.lm_head(?:\.|\s|$)"
)
print(device_idx)
# HC device_idx for non-layer vmfbs
device_idx = 0
lmhead = self.compile_lmhead(
lmhead,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
device=self.device,
device_idx=device_idx,
)
@@ -1452,13 +1488,13 @@ def generate(self, prompt, cli=False):

generated_token_op = self.generate_new_token(params=params)

prefill_time = time.time() - decode_st_time
decode_time = (time.time() - decode_st_time) * 1000

_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
history.append(_token)
yield self.tokenizer.decode(history), None, prefill_time
yield self.tokenizer.decode(history), None, decode_time

if _token == 2:
break
@@ -1667,12 +1703,7 @@ def remove_constant_dim(line):
new_lines = []

# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in self.model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in self.model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
pkv_tensor_shape = f"tensor<1x{self.n_layers_dict[self.model_name]}x?x128x"
if self.precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
@@ -2066,14 +2097,14 @@ def generate(self, prompt, cli):
generated_token_op = self.generate_new_token(
params=params, sharded=False, cli=cli
)
prefill_time = time.time() - prefill_st_time
prefill_time_ms = (time.time() - prefill_st_time) * 1000

token = generated_token_op["token"]
if "cpu" not in self.device:
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]
yield detok, None, prefill_time
yield detok, None, prefill_time_ms

res_tokens.append(token)
if cli:
@@ -2408,8 +2439,7 @@ def avg_and_stdev(data):
vic.shark_model.shark_runner.iree_config.device.flush_profiling()
if msg is None:
if is_first:
# Note that the prefill time is in seconds, and all the decoded tokens in ms.
prefill_time_ms = exec_time * 1000
prefill_time_ms = exec_time
is_first = False
else:
token_times_ms.append(exec_time)
41 changes: 30 additions & 11 deletions apps/language_models/src/model_wrappers/vicuna_sharded_model.py
@@ -1,4 +1,5 @@
import torch
import time


class FirstVicunaLayer(torch.nn.Module):
@@ -110,9 +111,11 @@ def __init__(self, shark_module):
self.model = shark_module

def forward(self, hidden_states):
hidden_states = hidden_states.detach()
hidden_states_sample = hidden_states.detach()

output = self.model("forward", (hidden_states,))
output = torch.tensor(output)

return output


Expand All @@ -136,8 +139,9 @@ def forward(self, hidden_states):
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = self.model("forward", (hidden_states,), send_to_host=True)
output = torch.tensor(output)

return output


@@ -158,8 +162,9 @@ def __init__(self, shark_module):

def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = self.model("forward", (input_ids,), send_to_host=True)
output = torch.tensor(output)

return output


@@ -178,21 +183,28 @@ def forward(
use_cache=True,
):
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
# hidden_states = hidden_states.detach()
# attention_mask = attention_mask.detach()
# position_ids = position_ids.detach()

output = self.model(
"first_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
),
send_to_host=True,
)

### send_to_host=True
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
### send_to_host=False
# output0 = output[0]
# output1 = output[1]
# output2 = output[2]

return (
output0,
@@ -202,11 +214,12 @@ def forward(
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
# hidden_states = hidden_states.detach()
# attention_mask = attention_mask.detach()
# position_ids = position_ids.detach()
# pkv0 = past_key_value[0].detach()
pkv0 = past_key_value[0]
pkv1 = past_key_value[1]
output = self.model(
"second_vicuna_forward",
(
Expand All @@ -216,11 +229,17 @@ def forward(
pkv0,
pkv1,
),
send_to_host=True,
)

### send_to_host=True
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
### send_to_host=False
# output0 = output[0]
# output1 = output[1]
# output2 = output[2]

return (
output0,