
Commit

Test moving doc into TBE file
Differential Revision: D61307727
sryap authored and facebook-github-bot committed Aug 15, 2024
1 parent ada1050 commit 63cd1fa
Showing 2 changed files with 171 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -7,6 +7,6 @@

# Trigger the manual addition of docstrings to pybind11-generated operators
try:
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
from . import jagged_tensor_ops # noqa: F401
except Exception:
pass
178 changes: 170 additions & 8 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -447,16 +447,178 @@ def generate_vbe_metadata(
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
Table Batched Embedding (TBE) operator. Please see
docs/table_batched_embedding_ops.py for the extended documentation.
Table Batched Embedding (TBE) operator. Looks up one or more
embedding tables. The module is intended for training. The
backward operator is fused with the optimizer, so the embedding
tables are updated during the backward pass.
Multiple sparse features can share one embedding table.
'feature_table_map' specifies the feature-table mapping.
T: number of logical tables
T_: number of physical tables
T >= T_

Args:
embedding_specs (List[Tuple[int, int, EmbeddingLocation,
ComputeDevice]]): A list of embedding specifications. Each
spec is a tuple of (number of embedding rows, embedding
dimension (must be a multiple of 4), table placement, compute
device).
For supported optimizer hyperparams, see inline comments below
feature_table_map (List[int], optional): An optional list that
specifies feature-table mapping.
cache_algorithm (CacheAlgorithm, optional): LXU cache
algorithm (`CacheAlgorithm.LRU`, `CacheAlgorithm.LFU`)
cache_load_factor (float, optional): The LXU cache capacity
which is `cache_load_factor` * the total number of rows in all
embedding tables
cache_sets (int, optional): The number of cache sets
cache_reserved_memory (float, optional): Amount of memory
reserved in HBM for non-cache purposes.
cache_precision (SparseType, optional): Data type of LXU cache
(`SparseType.FP32`, `SparseType.FP16`)
weights_precision (SparseType, optional): Data type of
embedding tables (also known as weights) (`SparseType.FP32`,
`SparseType.FP16`, `SparseType.INT8`)
output_dtype (SparseType, optional): Data type of an output
tensor (`SparseType.FP32`, `SparseType.FP16`,
`SparseType.INT8`)
enforce_hbm (bool, optional): If True, place all
weights/momentums in HBM when using cache
optimizer (OptimType, optional): An optimizer to use for
embedding table update in the backward pass.
(`OptimType.ADAM`, `OptimType.EXACT_ADAGRAD`,
`OptimType.EXACT_ROWWISE_ADAGRAD`, `OptimType.EXACT_SGD`,
`OptimType.LAMB`, `OptimType.LARS_SGD`,
`OptimType.PARTIAL_ROWWISE_ADAM`,
`OptimType.PARTIAL_ROWWISE_LAMB`, `OptimType.SGD`)
record_cache_metrics (RecordCacheMetrics, optional): Record
the number of hits, number of requests, etc., if
RecordCacheMetrics.record_cache_miss_counter is True, and
record similar metrics table-wise if
RecordCacheMetrics.record_tablewise_cache_miss is True
(default is None).
stochastic_rounding (bool, optional): If True, apply
stochastic rounding for weight types that are not
`SparseType.FP32`
gradient_clipping (bool, optional): If True, apply gradient
clipping
max_gradient (float, optional): The value for gradient
clipping
learning_rate (float, optional): The learning rate
eps (float, optional): The epsilon value used by Adagrad,
LAMB, and Adam
momentum (float, optional): Momentum used by LARS-SGD
weight_decay (float, optional): Weight decay used by LARS-SGD,
LAMB, ADAM, and Rowwise Adagrad
weight_decay_mode (WeightDecayMode, optional): Weight decay
mode (`WeightDecayMode.NONE`, `WeightDecayMode.L2`,
`WeightDecayMode.DECOUPLE`)
eta (float, optional): The eta value used by LARS-SGD
beta1 (float, optional): The beta1 value used by LAMB and ADAM
beta2 (float, optional): The beta2 value used by LAMB and ADAM
pooling_mode (PoolingMode, optional): Pooling mode
(`PoolingMode.SUM`, `PoolingMode.MEAN`, `PoolingMode.NONE`)
device (torch.device, optional): The current device to place
tensors on
bounds_check_mode (BoundsCheckMode, optional): If not set to
`BoundsCheckMode.NONE`, apply boundary check for indices
(`BoundsCheckMode.NONE`, `BoundsCheckMode.FATAL`,
`BoundsCheckMode.WARNING`, `BoundsCheckMode.IGNORE`)

Inputs:
indices (torch.Tensor): A 1D-tensor that contains indices to
be accessed in all embedding tables
offsets (torch.Tensor): A 1D-tensor that contains offsets of
indices. Shape `(B * T + 1)` where `B` = batch size and `T` =
number of tables. `offsets[t * B + b + 1] - offsets[t * B +
b]` is the length of bag `b` of table `t`
per_sample_weights (torch.Tensor, optional): An optional
1D-tensor that contains positional weights. Shape `(max(bag
length))`. Positional weight `i` is multiplied with all columns
of row `i` in each bag after it is read from the embedding table
and before pooling (if the pooling mode is not PoolingMode.NONE).
feature_requires_grad (torch.Tensor, optional): An optional
tensor for checking if `per_sample_weights` requires gradient

Returns:
A 2D-tensor containing looked up data. Shape `(B, total_D)`
where `B` = batch size and `total_D` = the sum of embedding
dimensions across all tables

Example:
>>> import torch
>>>
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
>>> EmbeddingLocation,
>>> )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
>>> SplitTableBatchedEmbeddingBagsCodegen,
>>> ComputeDevice,
>>> )
>>>
>>> # Two tables
>>> embedding_specs = [
>>> (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
>>> (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
>>> ]
>>>
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
>>> tbe.init_embedding_weights_uniform(-1, 1)
>>>
>>> print(tbe.split_embedding_weights())
[tensor([[-0.9426, 0.7046, 0.4214, -0.0419, 0.1331, -0.7856, -0.8124, -0.2021],
[-0.5771, 0.5911, -0.7792, -0.1068, -0.6203, 0.4813, -0.1677, 0.4790],
[-0.5587, -0.0941, 0.5754, 0.3475, -0.8952, -0.1964, 0.0810, -0.4174]],
device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775, 0.3273],
[-0.5399, -0.0229, -0.1455, -0.8770],
[-0.9520, 0.4593, -0.7169, 0.6307],
[-0.1765, 0.8757, 0.8614, 0.2051],
[-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size = 3
>>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
>>> device="cuda",
>>> dtype=torch.long)
>>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
>>> device="cuda",
>>> dtype=torch.long)
>>>
>>> output = tbe(indices, offsets)
>>>
>>> # Batch size = 3, total embedding dimension = 12
>>> print(output.shape)
torch.Size([3, 12])
>>> print(output)
tensor([[-1.5197, 1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801, 0.2769,
-0.7164, 0.8528, 0.7159, -0.6719],
[-2.0784, 1.2016, 0.2176, 0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
-1.2637, -0.9427, -1.8902, 0.3754],
[-1.5013, 0.6105, 0.9968, 0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
-0.2513, -0.4039, -0.3775, 0.3273]], device='cuda:0',
grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
"""

embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
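
The docstring above notes that multiple sparse features can share one physical embedding table through `feature_table_map`, with T logical tables mapped onto T_ physical tables (T >= T_). Below is a minimal sketch of that mapping, not part of this commit; it reuses the classes from the docstring example, and the mapping values are purely illustrative.

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Two physical tables (T_ = 2), same specs as the docstring example.
embedding_specs = [
    (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA),
]

# Three logical features (T = 3): features 0 and 2 share physical table 0,
# feature 1 uses physical table 1 (illustrative mapping, T >= T_).
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs,
    feature_table_map=[0, 1, 0],
)

With this mapping, the `indices`/`offsets` inputs are laid out over the T = 3 logical features rather than the 2 physical tables.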

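As a companion to the docstring example, here is a short sketch, not part of this commit, of how the example's `offsets` tensor relates to per-bag lengths: since `offsets[t * B + b + 1] - offsets[t * B + b]` is the length of bag `b` of table `t`, `offsets` is a prefix sum of the flattened bag lengths with a leading zero.

import torch

# Bag lengths laid out as lengths[t * B + b] for T = 2 tables and batch size B = 3.
# Table 0 bags: [0, 1], [2, 0, 1], [2, 0]; table 1 bags: [3, 1], [4, 2, 0], [0].
lengths = torch.tensor([2, 3, 2, 2, 3, 1], dtype=torch.long)

# offsets[0] = 0 and offsets[i + 1] = offsets[i] + lengths[i].
offsets = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lengths, dim=0)])
print(offsets)  # tensor([ 0,  2,  5,  7,  9, 12, 13])

This reproduces the `offsets` tensor used in the docstring example (placed on `device="cuda"` there).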