Use KTRegroupAsDict to Replace KeyedTensor.regroup_as_dict in EBC_sparse_arch (pytorch#2272)

Summary:
Pull Request resolved: pytorch#2272

# context
* Currently in APF, the class method `KeyedTensor.regroup_as_dict` is used for permuting and regrouping the pooled embeddings.
* This function is not very efficient in training because it recomputes the necessary metadata arguments for the fbgemm operator on every call. Moreover, it also needs host-to-device data transfers to move the metadata tensors to the GPU: [codepointer](https://fburl.com/code/fmhkg6sn)
```
    # metadata calculation for permute, offsets, etc.; all are tensors
    permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
        keys, lengths, groups
    )
    values = torch.concat(values, dim=1)
    device = values.device
    permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
        values,
        _pin_and_move(offsets, device),  # needs to use pinned memory and
        _pin_and_move(permute, device),  # initiate several H2D transfers
        _pin_and_move(inv_offsets, device),
        _pin_and_move(inv_permute, device),
    )
```
* However, these metadata tensors won't change during training, so we can cache the results from the first batch and reuse them.
* The recommended replacement is `KTRegroupAsDict`, a module that does the same permuting and regrouping work. It stores the metadata tensors as instance variables during the first forward pass and then reuses them directly on every subsequent pass (see the sketch after this list).
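
As a rough illustration of the caching idea, here is a minimal sketch. This is not the actual `KTRegroupAsDict` implementation: it substitutes a plain `index_select` for the fbgemm permute op, and every name in it is made up for illustration.

```
import torch
from typing import Dict, List, Optional


class CachedRegroup(torch.nn.Module):
    """Minimal sketch of the caching pattern: build permute metadata once,
    keep it on-device, and reuse it on every subsequent batch."""

    def __init__(self, keys: List[str], dims: List[int], groups: List[List[str]]) -> None:
        super().__init__()
        self.keys = keys      # feature names, in concat order
        self.dims = dims      # per-feature embedding dims (fixed by config)
        self.groups = groups  # output groups, each a list of feature names
        self._perm: Optional[torch.Tensor] = None  # cached column permutation
        self._splits: Optional[List[int]] = None   # cached output group widths

    def _build_metadata(self, device: torch.device) -> None:
        # one-time metadata computation + a single host-to-device transfer
        offsets = [0]
        for d in self.dims:
            offsets.append(offsets[-1] + d)
        slices = {k: (offsets[i], offsets[i + 1]) for i, k in enumerate(self.keys)}
        perm: List[int] = []
        splits: List[int] = []
        for group in self.groups:
            width = 0
            for key in group:
                start, end = slices[key]
                perm.extend(range(start, end))
                width += end - start
            splits.append(width)
        self._perm = torch.tensor(perm, dtype=torch.long, device=device)
        self._splits = splits

    def forward(self, values: torch.Tensor, names: List[str]) -> Dict[str, torch.Tensor]:
        if self._perm is None:  # first batch only
            self._build_metadata(values.device)
        permuted = values.index_select(1, self._perm)  # reuses cached indices
        return dict(zip(names, permuted.split(self._splits, dim=1)))
```

On the first call the module computes the permutation indices and moves them to the input's device; every later call reuses the cached, device-resident tensor, which is what removes the per-batch metadata computation and H2D transfers:

```
module = CachedRegroup(keys=["user", "object"], dims=[64, 128], groups=[["object"], ["user"]])
out = module(torch.randn(128, 192), names=["object_group", "user_group"])  # builds + caches metadata
out = module(torch.randn(128, 192), names=["object_group", "user_group"])  # pure cached path
```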

NOTE: This `KTRegroupAsDict` module is also IR-compatible with the custom-op approach in D57578012.

# numbers with IG FM model
* [baseline trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_baseline-a6e954af60%2F1%2Frank-0.Aug_04_00_01_16.3621.pt.trace.json.gz&bucket=aps_traces), [experimental trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-ig_ctr_aps_vanilla_both_diffs-3503324cb2%2F4%2Frank-0.Aug_03_20_16_01.3608.pt.trace.json.gz&bucket=aps_traces)
* The GPU runtime saving is about 1 ms, which is consistent with the benchmark results. The overall duration per batch is about 200 ms, so the improvement is roughly 0.5%.
 {F1792260523}
* The CPU runtime saving is about 1.3 ms.
 {F1792260686}

# additional notes
1. We don't expect any change in results; the two approaches should produce exactly the same outputs.
2. `regroup_as_dict` works on pooled embeddings, i.e., the results of the embedding-table lookup, so for each feature the length is constant, as defined in the embedding config (see the illustrative snippet below).
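
To make note 2 concrete, here is a purely illustrative config (table names and sizes are made up, not from this diff): the embedding dimension of each feature is fixed up front, so the width of the concatenated pooled values is batch-invariant, which is exactly what makes the metadata cacheable.

```
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# hypothetical table configs: embedding_dim is fixed here, so pooled "user"
# embeddings are always [B, 64] and pooled "object" embeddings are always
# [B, 128], regardless of the batch contents
user_table = EmbeddingBagConfig(
    name="t_user", embedding_dim=64, num_embeddings=10_000, feature_names=["user"]
)
object_table = EmbeddingBagConfig(
    name="t_object", embedding_dim=128, num_embeddings=10_000, feature_names=["object"]
)
```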

Reviewed By: really121

Differential Revision: D43405610
TroyGarden authored and facebook-github-bot committed Aug 10, 2024
1 parent d60400b commit 1d2c5a3
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions torchrec/modules/tests/test_regroup.py
```
@@ -33,13 +33,52 @@ def setUp(self) -> None:
        self.keys = ["user", "object"]
        self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()

    def new_kts(self) -> None:
        self.kts = build_kts(
            dense_features=20,
            sparse_features=20,
            dim_dense=64,
            dim_sparse=128,
            batch_size=128,
            device=torch.device("cpu"),
            run_backward=True,
        )

    def test_regroup_backward_skips_and_duplicates(self) -> None:
        groups = build_groups(
            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
        )
        assert _all_keys_used_once(self.kts, groups) is False

        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)

        # first run
        tensor_groups = regroup_module(self.kts)
        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        # clear grads so we can reuse the inputs
        self.kts[0].values().grad = None
        self.kts[1].values().grad = None

        tensor_groups = KeyedTensor.regroup_as_dict(
            keyed_tensors=self.kts, groups=groups, keys=self.keys
        )
        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        self.assertTrue(torch.allclose(pred0, pred1))
        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))

        # second run
        self.new_kts()
        tensor_groups = regroup_module(self.kts)
        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
```
