Skip to content

Commit

Permalink
add dense TBE to template and enable VBE support (pytorch#2620)
Browse files Browse the repository at this point in the history
Summary:

- make the dense TBE headers into a template
- add VBE options to the dense TBE header

Differential Revision: D57017981
  • Loading branch information
joshuadeng authored and facebook-github-bot committed May 28, 2024
1 parent 9f1d0ef commit b2e5c16
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 498 deletions.
15 changes: 13 additions & 2 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def render_backward_templates(
):
if nobag and (weighted or vbe):
continue
if kwargs.get("dense") and (vbe or ssd):
if kwargs.get("dense") and ssd:
continue
if ssd and (vbe or is_gwd):
continue
Expand Down Expand Up @@ -152,7 +152,7 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
**kwargs,
)

# Generate the backward splits (non-dense)
# Generate the backward splits
# We generate only the API to preserve the backward compatibility if
# has_gpu_support=True
if not kwargs.get("dense"):
Expand Down Expand Up @@ -197,6 +197,17 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
)

else:
template_filepath = (
"training/backward/embedding_backward_split_host_template.cpp"
)
filename = "gen_embedding_backward_split_dense.cpp"
CodeTemplate.load(template_filepath).write(
filename,
is_forward=False,
**kwargs,
)

@staticmethod
def generate_backward_split_cpu(**kwargs: Any) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion fbgemm_gpu/codegen/genscript/generate_forward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def render_forward_templates(
):
if nobag and (weighted or vbe):
continue
if dense and (vbe or ssd):
if dense and ssd:
continue
if ssd and (vbe or is_gwd):
continue
Expand All @@ -64,6 +64,7 @@ def render_forward_templates(
is_gwd=is_gwd,
)

@staticmethod
def generate_pt2_wrappers() -> None:
# Generate PT2 forward wrapper (CUDA)
CodeTemplate.load(
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def dense() -> Dict[str, Any]:
),
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": False,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
"has_ssd_support": False,
}
Expand Down
Loading

0 comments on commit b2e5c16

Please sign in to comment.