Commit b31daed

Add VBE to Dense TBE frontend (#2628)
Summary:
Pull Request resolved: #2628

- add frontend support to the Dense TBE module for variable batch size embeddings (VBE); see the usage sketch below
- add a unit test
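
For context, a minimal usage sketch of the new frontend path (not part of the diff). The constructor arguments below are assumptions based on typical TBE usage; only the `batch_size_per_feature_per_rank` argument is introduced by this change.

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    DenseTableBatchedEmbeddingBagsCodegen,
)

# Two tables: (rows, dim) = (100, 4) and (100, 8); one feature per table.
emb = DenseTableBatchedEmbeddingBagsCodegen(
    [(100, 4), (100, 8)],
    pooling_mode=PoolingMode.SUM,
    use_cpu=True,
)

# Two ranks; feature 0 sees batches [2, 1], feature 1 sees [3, 4] -> 10 bags total.
batch_size_per_feature_per_rank = [[2, 1], [3, 4]]

indices = torch.randint(0, 100, (20,), dtype=torch.long)
offsets = torch.arange(0, 21, 2, dtype=torch.long)  # 10 bags, 2 ids each

# With VBE, the output is a flat 1-D tensor of vbe_output_size elements.
out = emb(
    indices,
    offsets,
    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
)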

Differential Revision: D56651380
joshuadeng authored and facebook-github-bot committed May 23, 2024
1 parent d5c94f1 commit b31daed
Showing 2 changed files with 229 additions and 126 deletions.
278 changes: 165 additions & 113 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -307,6 +307,126 @@ def apply_split_helper(
)


def generate_vbe_metadata(
offsets: Tensor,
batch_size_per_feature_per_rank: Optional[List[List[int]]],
optimizer: OptimType,
pooling_mode: PoolingMode,
feature_dims: Tensor,
device: torch.device,
) -> invokers.lookup_args.VBEMetadata:
"""
Generate VBE metadata based on batch_size_per_feature_per_rank.
Metadata includes:
1) B_offsets - A tensor that contains batch size offsets for each
feature
2) output_offsets_feature_rank - A tensor that contains output
offsets for each feature and rank
3) B_offsets_rank_per_feature - A tensor that contains batch size
offsets for each feature and rank
4) max_B - The maximum batch size for all features
5) max_B_feature_rank - The maximum batch size for all ranks and
features
6) output_size - The output size (number of elements)
"""
if batch_size_per_feature_per_rank is not None:
assert (
optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
or optimizer == OptimType.EXACT_SGD
or optimizer == OptimType.NONE
), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only"
assert (
pooling_mode != PoolingMode.NONE
), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
# TODO: Add input check
zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)

# Create B offsets
total_batch_size_per_feature = torch.tensor(
batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
).sum(dim=1)

max_B = total_batch_size_per_feature.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B)
torch._check(max_B < offsets.numel())

Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
B_offsets = Bs.cumsum(dim=0).to(torch.int)

# Create output offsets
B_feature_rank = torch.tensor(
batch_size_per_feature_per_rank,
device="cpu",
dtype=torch.int64,
)
max_B_feature_rank = B_feature_rank.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B_feature_rank)
torch._check(max_B_feature_rank <= offsets.size(0))
# D->H only once
feature_dims = feature_dims.cpu()
output_sizes_feature_rank = B_feature_rank.transpose(0, 1) * feature_dims.view(
1, -1
)
output_offsets_feature_rank = torch.concat(
[
zero_tensor.to(torch.int64),
output_sizes_feature_rank.flatten().cumsum(dim=0),
]
)
output_size = output_offsets_feature_rank[-1].item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(output_size)

# TODO: Support INT8 output
# B_offsets_rank_per_feature is for rank and (b, t) mapping
B_offsets_rank_per_feature = (
torch.tensor(
[
[0] + batch_size_per_feature
for batch_size_per_feature in batch_size_per_feature_per_rank
],
device="cpu",
dtype=torch.int32,
)
.cumsum(dim=1)
.to(torch.int)
)

B_offsets = B_offsets.to(device, non_blocking=True)
output_offsets_feature_rank = output_offsets_feature_rank.to(
device, non_blocking=True
)
B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
device, non_blocking=True
)

# TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=B_offsets,
output_offsets_feature_rank=output_offsets_feature_rank,
B_offsets_rank_per_feature=B_offsets_rank_per_feature,
# pyre-ignore
max_B=max_B,
# pyre-ignore
max_B_feature_rank=max_B_feature_rank,
# pyre-ignore
output_size=output_size,
)
else:
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=None,
output_offsets_feature_rank=None,
B_offsets_rank_per_feature=None,
max_B=-1,
max_B_feature_rank=-1,
output_size=-1,
)
return vbe_metadata
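
For intuition, a worked toy example of the metadata this function produces (hand-checked, not part of the diff), using plain torch to mirror the same arithmetic:

import torch

# 2 features x 2 ranks; feature dims [4, 8]
batch_size_per_feature_per_rank = [[2, 1], [3, 4]]
feature_dims = torch.tensor([4, 8], dtype=torch.int64)

B = torch.tensor(batch_size_per_feature_per_rank, dtype=torch.int64)
zero = torch.zeros(1, dtype=torch.int64)

# Batch-size offsets per feature: total batches [3, 7] -> [0, 3, 10]; max_B = 7
B_offsets = torch.cat([zero, B.sum(dim=1)]).cumsum(dim=0)

# Per-(rank, feature) output sizes: rank 0 -> [2*4, 3*8], rank 1 -> [1*4, 4*8]
sizes = B.transpose(0, 1) * feature_dims.view(1, -1)

# Output offsets [0, 8, 32, 36, 68] -> output_size = 68
output_offsets_feature_rank = torch.cat([zero, sizes.flatten().cumsum(dim=0)])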


# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
@@ -1092,120 +1212,20 @@ def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
)
return data

@torch.jit.ignore
def _generate_vbe_metadata(
self,
offsets: Tensor,
batch_size_per_feature_per_rank: Optional[List[List[int]]],
) -> invokers.lookup_args.VBEMetadata:
"""
Generate VBE metadata based on batch_size_per_feature_per_rank.
Metadata includes:
1) B_offsets - A tensor that contains batch size offsets for each
feature
2) output_offsets_feature_rank - A tensor that contains output
offsets for each feature
3) B_offsets_per_rank_per_feature - A tensor that contains batch
size offsets for each feature
and rank
4) max_B - The maximum batch size for all features
5) max_B_feature_rank - The maximum batch size for all ranks and
features
6) output_size - The output size (number of elements)
"""
if batch_size_per_feature_per_rank is not None:
assert (
self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
or self.optimizer == OptimType.EXACT_SGD
), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only"
assert (
self.pooling_mode != PoolingMode.NONE.value
), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
# TODO: Add input check
zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)

# Create B offsets
total_batch_size_per_feature = torch.tensor(
batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
).sum(dim=1)

max_B = total_batch_size_per_feature.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B)
torch._check(max_B < offsets.numel())

Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
B_offsets = Bs.cumsum(dim=0).to(torch.int)

# Create output offsets
B_feature_rank = torch.tensor(
batch_size_per_feature_per_rank,
device="cpu",
dtype=torch.int64,
)
max_B_feature_rank = B_feature_rank.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B_feature_rank)
torch._check(max_B_feature_rank <= offsets.size(0))
# D->H only once
self.feature_dims = self.feature_dims.cpu()
output_sizes_feature_rank = B_feature_rank.transpose(
0, 1
) * self.feature_dims.view(1, -1)
output_offsets_feature_rank = torch.concat(
[
zero_tensor.to(torch.int64),
output_sizes_feature_rank.flatten().cumsum(dim=0),
]
)
output_size = output_offsets_feature_rank[-1].item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(output_size)

# TODO: Support INT8 output
# B_offsets_rank_per_feature is for rank and (b, t) mapping
B_offsets_rank_per_feature = (
torch.tensor(
[
[0] + batch_size_per_feature
for batch_size_per_feature in batch_size_per_feature_per_rank
],
device="cpu",
dtype=torch.int32,
)
.cumsum(dim=1)
.to(torch.int)
)

B_offsets = B_offsets.to(self.current_device, non_blocking=True)
output_offsets_feature_rank = output_offsets_feature_rank.to(
self.current_device, non_blocking=True
)
B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
self.current_device, non_blocking=True
)

# TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=B_offsets,
output_offsets_feature_rank=output_offsets_feature_rank,
B_offsets_rank_per_feature=B_offsets_rank_per_feature,
# pyre-ignore
max_B=max_B,
# pyre-ignore
max_B_feature_rank=max_B_feature_rank,
# pyre-ignore
output_size=output_size,
)
else:
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=None,
output_offsets_feature_rank=None,
B_offsets_rank_per_feature=None,
max_B=-1,
max_B_feature_rank=-1,
output_size=-1,
)
return vbe_metadata
return generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
self.optimizer,
self.pooling_mode,
self.feature_dims,
self.current_device,
)

def forward( # noqa: C901
self,
@@ -1571,8 +1591,7 @@ def prefetch(
self.prefetch_stream != forward_stream
), "prefetch_stream and forward_stream should not be the same stream"
vbe_metadata = self._generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
offsets, batch_size_per_feature_per_rank
)
self._prefetch(
indices,
@@ -2473,15 +2492,21 @@ def __init__(
)
T = len(feature_table_map)
assert T_ <= T
D_offsets = [dims[t] for t in feature_table_map]
D_offsets = [0] + list(accumulate(D_offsets))

feature_dims = [dims[t] for t in feature_table_map]
D_offsets = [0] + list(accumulate(feature_dims))
self.total_D = D_offsets[-1]
self.max_D = max(dims)
self.register_buffer(
"D_offsets",
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
)
assert self.D_offsets.numel() == T + 1
# Required for VBE
self.register_buffer(
"feature_dims",
torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
)
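
A minimal sketch of the derivation above under assumed inputs (two tables shared by three features; not part of the diff):

from itertools import accumulate

dims = [4, 8]                  # embedding dim per table
feature_table_map = [0, 0, 1]  # three features mapped onto two tables

feature_dims = [dims[t] for t in feature_table_map]  # [4, 4, 8]
D_offsets = [0] + list(accumulate(feature_dims))     # [0, 4, 8, 16]
# feature_dims is registered as a CPU int64 buffer so generate_vbe_metadata
# can size per-(rank, feature) outputs without a device-to-host copy per call.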

hash_size_cumsum = [0] + list(accumulate(rows))
if hash_size_cumsum[-1] == 0:
@@ -2536,13 +2561,34 @@ def __init__(
),
)

@torch.jit.ignore
def _generate_vbe_metadata(
self,
offsets: Tensor,
batch_size_per_feature_per_rank: Optional[List[List[int]]],
) -> invokers.lookup_args.VBEMetadata:
return generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
OptimType.NONE,
self.pooling_mode,
self.feature_dims,
self.current_device,
)

def forward(
self,
indices: Tensor,
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
feature_requires_grad: Optional[Tensor] = None,
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
) -> Tensor:
# Generate VBE metadata
vbe_metadata = self._generate_vbe_metadata(
offsets, batch_size_per_feature_per_rank
)

(indices, offsets) = indices.long(), offsets.long()
# Force casting per_sample_weights to float
if per_sample_weights is not None:
@@ -2562,6 +2608,12 @@ def forward(
indice_weights=per_sample_weights,
feature_requires_grad=feature_requires_grad,
output_dtype=self.output_dtype,
B_offsets=vbe_metadata.B_offsets,
vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
max_B=vbe_metadata.max_B,
max_B_feature_rank=vbe_metadata.max_B_feature_rank,
vbe_output_size=vbe_metadata.output_size,
)
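
With VBE enabled, the lookup above returns a flat 1-D tensor whose (rank, feature) segments are laid out rank-major, i.e. segment index = rank * num_features + feature. A hedged sketch of how a caller might slice out one segment (hypothetical helper, not part of this diff):

def vbe_output_slice(
    out: Tensor,
    output_offsets_feature_rank: Tensor,
    rank: int,
    feature: int,
    num_features: int,
) -> Tensor:
    # output_offsets_feature_rank has one entry per (rank, feature) pair
    # plus a leading zero, so adjacent entries bound one segment.
    idx = rank * num_features + feature
    start = int(output_offsets_feature_rank[idx])
    end = int(output_offsets_feature_rank[idx + 1])
    return out[start:end]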

@torch.jit.export
