diff --git a/tests/peft/peft_alignment_test.py b/tests/peft/peft_alignment_test.py
index 5843ffa3d9..4e8166e2a9 100644
--- a/tests/peft/peft_alignment_test.py
+++ b/tests/peft/peft_alignment_test.py
@@ -368,7 +368,7 @@ def get_hf_tensor(hf_tensor_name, tensor_comparison_idx):
             hf_tensor = torch.load(hf_tensor_path, map_location='cpu')
             return hf_tensor
 
-        def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE, pre=False):
+        def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE, pre=False, shard_axis=0):
             ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else ""
             ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else ""
             ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}"
@@ -381,7 +381,7 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp
 
             ff_shape = list(hf_shape)[::-1]
             if tp_type == TPType.PARTITION:
-                ff_shape[0] //= self.tp_degree
+                ff_shape[shard_axis] //= self.tp_degree
 
             # exception: intermediate attention tensors
             intermediate_attention_tensor = (
@@ -405,10 +405,10 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp
                 ff_tensor = ff_tensors[0]
             # if partition, concatenate along the partition dimension
             elif tp_type == TPType.PARTITION:
-                ff_tensor = np.concatenate(ff_tensors, axis=0)
+                ff_tensor = np.concatenate(ff_tensors, axis=shard_axis)
             # if to_reduce, sum along the partition dimension
             elif tp_type == TPType.TO_REDUCE:
-                ff_tensor = np.sum(ff_tensors, axis=0)
+                ff_tensor = np.sum(ff_tensors, axis=shard_axis)
             else:
                 ff_tensor = ff_tensors[0]
             ff_tensor = torch.from_numpy(ff_tensor)
@@ -551,7 +551,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             mixed_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, mixed_comparison)
             hf_tensor = hf_tensor.squeeze().T
-            ff_tensor = get_ff_tensor(ff_tensor_name, mixed_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+            ff_tensor = get_ff_tensor(ff_tensor_name, mixed_comparison, hf_tensor.shape, tp_type=TPType.PARTITION, shard_axis=1)
             compare(hf_tensor, ff_tensor, label=f"V-proj {i} gradient input")
 
             # K-proj grads
@@ -562,7 +562,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             hf_tensor = get_hf_tensor(hf_tensor_name, k_proj_comparison)
             hf_tensor = hf_tensor.squeeze().view(self.num_tokens, self.num_attention_heads, self.projsize).transpose(1, 2).contiguous()
             hf_tensor = hf_tensor.T
-            ff_tensor = get_ff_tensor(ff_tensor_name, k_proj_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+            ff_tensor = get_ff_tensor(ff_tensor_name, k_proj_comparison, hf_tensor.shape, tp_type=TPType.PARTITION, shard_axis=2)
             compare(hf_tensor, ff_tensor, label=f"K-proj {i} gradient input")
 
             # Q-proj grads
@@ -574,7 +574,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             hf_tensor = get_hf_tensor(hf_tensor_name, q_proj_comparison)
             hf_tensor = hf_tensor.view(self.num_tokens, self.num_attention_heads, self.projsize).transpose(1, 2).contiguous().T
             augmented_hf_tensor_shape = torch.Size([3]+list(hf_tensor.size()))
-            ff_tensor = get_ff_tensor(ff_tensor_name, q_proj_comparison, augmented_hf_tensor_shape, tp_type=TPType.PARTITION)[:,:,:,0]
+            ff_tensor = get_ff_tensor(ff_tensor_name, q_proj_comparison, augmented_hf_tensor_shape, tp_type=TPType.PARTITION, shard_axis=2)[:,:,:,0]
             compare(hf_tensor, ff_tensor, label=f"Q-proj {i} gradient input")
 
             # FF Attn input with HF layernorm out
@@ -582,7 +582,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             ff_tensor_name = f"layers.{i}.layers.{i}.self_attn"
             input_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
-            ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+            ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
             compare(hf_tensor, ff_tensor, label=f"Attn input {i} gradient input")
 
             if i > 0:
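For context (not part of the patch), a minimal sketch of the shard-recombination logic that the new shard_axis parameter generalizes: PARTITION-type tensors are concatenated back along the chosen axis, while TO_REDUCE-type tensors are summed element-wise across shards. The tp_degree of 2, the (4, 6) array, and shard_axis=1 below are made-up values for illustration only.

import numpy as np

# Hypothetical: two tensor-parallel shards of a (4, 6) activation, split along
# axis 1. shard_axis selects which axis the shards are concatenated on.
tp_degree = 2
shard_axis = 1
full = np.arange(24, dtype=np.float32).reshape(4, 6)
shards = np.split(full, tp_degree, axis=shard_axis)    # each shard is (4, 3)

# TPType.PARTITION: stitch the shards back together along shard_axis
reassembled = np.concatenate(shards, axis=shard_axis)
assert np.array_equal(reassembled, full)

# TPType.TO_REDUCE: each shard holds a full-shape partial sum, so the list of
# shards is summed element-wise (axis 0 of the stacked list) to recover the total
partials = [full / tp_degree] * tp_degree
reduced = np.sum(partials, axis=0)
assert np.allclose(reduced, full)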