
Commit

fix alignment up to input ln
goliaro committed Aug 20, 2024
1 parent 9ca3687 commit d0e98ec
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/peft/peft_alignment_test.py
@@ -368,7 +368,7 @@ def get_hf_tensor(hf_tensor_name, tensor_comparison_idx):
hf_tensor = torch.load(hf_tensor_path, map_location='cpu')
return hf_tensor

- def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE, pre=False):
+ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE, pre=False, shard_axis=0):
ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else ""
ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else ""
ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}"
@@ -381,7 +381,7 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp

ff_shape = list(hf_shape)[::-1]
if tp_type == TPType.PARTITION:
- ff_shape[0] //= self.tp_degree
+ ff_shape[shard_axis] //= self.tp_degree

# exception: intermediate attention tensors
intermediate_attention_tensor = (
@@ -405,10 +405,10 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp
ff_tensor = ff_tensors[0]
# if partition, concatenate along the partition dimension
elif tp_type == TPType.PARTITION:
- ff_tensor = np.concatenate(ff_tensors, axis=0)
+ ff_tensor = np.concatenate(ff_tensors, axis=shard_axis)
# if to_reduce, sum along the partition dimension
elif tp_type == TPType.TO_REDUCE:
- ff_tensor = np.sum(ff_tensors, axis=0)
+ ff_tensor = np.sum(ff_tensors, axis=shard_axis)
else:
ff_tensor = ff_tensors[0]
ff_tensor = torch.from_numpy(ff_tensor)
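
For context, the hunks above replace the hard-coded axis 0 with an explicit shard_axis, both when deriving the expected per-shard shape and when recombining the per-GPU tensor files. A minimal standalone sketch of that logic, assuming a list of numpy shards and plain strings in place of the TPType enum (the names expected_shard_shape and reconstruct, and the toy shapes, are illustrative and not part of the test file):

# Illustrative sketch, not part of tests/peft/peft_alignment_test.py.
import numpy as np

def expected_shard_shape(hf_shape, tp_degree, partitioned, shard_axis=0):
    # FlexFlow lists dimensions in reverse order relative to HF; a partitioned
    # tensor is additionally split across tp_degree shards along shard_axis.
    ff_shape = list(hf_shape)[::-1]
    if partitioned:
        ff_shape[shard_axis] //= tp_degree
    return ff_shape

def reconstruct(shards, tp_type, shard_axis=0):
    if tp_type == "PARTITION":
        # shards hold disjoint slices: concatenate them along shard_axis
        return np.concatenate(shards, axis=shard_axis)
    if tp_type == "TO_REDUCE":
        # shards hold same-shape partial sums: add them elementwise
        # (written here as a stack-and-sum over the shard list)
        return np.sum(np.stack(shards), axis=0)
    # REPLICATE: every shard already holds the full tensor
    return shards[0]

# Example: an HF tensor of shape (6, 4), partitioned across 2 shards along FF axis 1.
full = np.arange(24).reshape(4, 6)   # FF view: dims reversed relative to HF
shards = np.split(full, 2, axis=1)
assert expected_shard_shape((6, 4), tp_degree=2, partitioned=True, shard_axis=1) == [4, 3]
assert np.array_equal(reconstruct(shards, "PARTITION", shard_axis=1), full)
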
@@ -551,7 +551,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
mixed_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
hf_tensor = get_hf_tensor(hf_tensor_name, mixed_comparison)
hf_tensor = hf_tensor.squeeze().T
- ff_tensor = get_ff_tensor(ff_tensor_name, mixed_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+ ff_tensor = get_ff_tensor(ff_tensor_name, mixed_comparison, hf_tensor.shape, tp_type=TPType.PARTITION, shard_axis=1)
compare(hf_tensor, ff_tensor, label=f"V-proj {i} gradient input")

# K-proj grads
@@ -562,7 +562,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
hf_tensor = get_hf_tensor(hf_tensor_name, k_proj_comparison)
hf_tensor = hf_tensor.squeeze().view(self.num_tokens, self.num_attention_heads, self.projsize).transpose(1, 2).contiguous()
hf_tensor = hf_tensor.T
- ff_tensor = get_ff_tensor(ff_tensor_name, k_proj_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+ ff_tensor = get_ff_tensor(ff_tensor_name, k_proj_comparison, hf_tensor.shape, tp_type=TPType.PARTITION, shard_axis=2)
compare(hf_tensor, ff_tensor, label=f"K-proj {i} gradient input")

# Q-proj grads
@@ -574,15 +574,15 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
hf_tensor = get_hf_tensor(hf_tensor_name, q_proj_comparison)
hf_tensor = hf_tensor.view(self.num_tokens, self.num_attention_heads, self.projsize).transpose(1, 2).contiguous().T
augmented_hf_tensor_shape = torch.Size([3]+list(hf_tensor.size()))
- ff_tensor = get_ff_tensor(ff_tensor_name, q_proj_comparison, augmented_hf_tensor_shape, tp_type=TPType.PARTITION)[:,:,:,0]
+ ff_tensor = get_ff_tensor(ff_tensor_name, q_proj_comparison, augmented_hf_tensor_shape, tp_type=TPType.PARTITION, shard_axis=2)[:,:,:,0]
compare(hf_tensor, ff_tensor, label=f"Q-proj {i} gradient input")

# FF Attn input with HF layernorm out
hf_tensor_name = f"layers.{i}.input_layernorm"
ff_tensor_name = f"layers.{i}.layers.{i}.self_attn"
input_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
- ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+ ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
compare(hf_tensor, ff_tensor, label=f"Attn input {i} gradient input")

if i > 0:
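
The call-site hunks pass the new argument explicitly (shard_axis=1 for the V-proj input-gradient comparison, shard_axis=2 for the K-proj and Q-proj ones), and, per the commit title, the comparison of the attention input gradient against the HF input layernorm switches from TPType.PARTITION to TPType.TO_REDUCE: the per-shard tensors are now treated as partial sums to add together rather than disjoint slices to concatenate. A toy numpy illustration of that distinction, using made-up shapes rather than the model's real dimensions:

# Toy illustration (made-up 4x8 shape): PARTITION shards are disjoint slices,
# TO_REDUCE shards are full-shape partial contributions.
import numpy as np

rng = np.random.default_rng(0)
full_grad = rng.standard_normal((4, 8))

# PARTITION along axis 1: each shard owns a different block of columns.
partition_shards = np.split(full_grad, 2, axis=1)
assert np.array_equal(np.concatenate(partition_shards, axis=1), full_grad)

# TO_REDUCE: each shard holds a partial sum over the same elements;
# only adding the shards recovers the full gradient.
reduce_shards = [0.25 * full_grad, 0.75 * full_grad]
assert np.allclose(reduce_shards[0] + reduce_shards[1], full_grad)
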
