
remove unnecessary .contiguous() in fp8 backward
xrsrke committed Nov 1, 2024
1 parent b4156dc commit 39a4960
Showing 1 changed file with 5 additions and 2 deletions.
src/nanotron/fp8/linear.py (5 additions, 2 deletions)

@@ -248,9 +248,12 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
         assert grad_weight.dtype == recipe.accum_dtype
         # TODO(xrsrke): maintain a persistence metadata across training

-        grad_weight = grad_weight.T.contiguous()
+        # grad_weight = grad_weight.T.contiguous()
+        # orig_shape = grad_weight.shape
+        # grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.T
         orig_shape = grad_weight.shape
-        grad_weight = grad_weight.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)
+        grad_weight = grad_weight.t().view(-1).reshape(orig_shape)

         # NOTE: if use gradient accumulation, then directly keep the high precision weights for later accumulate
         if constants.CONFIG is not None and (
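
For context, a minimal sketch (not part of the commit, shapes are illustrative) of why the dropped .contiguous() calls are redundant: transposing a transposed view restores the original contiguous layout, so .view(-1) already operates on contiguous memory and the pure-view chain yields the same values as the old copy-heavy one.

    import torch

    # Hypothetical stand-in for the FP8 weight gradient.
    grad_weight = torch.randn(4, 3)

    # Old path: each .contiguous() may force an extra memory copy.
    old = grad_weight.T.contiguous()
    orig_shape = old.shape
    old = old.contiguous().t().contiguous().view(-1).contiguous().reshape(orig_shape)

    # New path: .T followed by .t() returns to the original contiguous layout,
    # so .view(-1) succeeds without any copy.
    new = grad_weight.T
    new = new.t().view(-1).reshape(orig_shape)

    assert new.is_contiguous()
    assert torch.equal(old, new)  # identical values, no intermediate copies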
