Skip to content

Commit

Permalink
Fix IEN cogwheel test with new pruning attribute (pytorch#2364)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2364

https://www.internalfb.com/intern/test/281475139080930/ broke as a result of adding num_embeddings_post_processing. This change resolves this BC issue in the tests without reverting.

Reviewed By: guangyuwang

Differential Revision: D62225812

fbshipit-source-id: 4e718faccc9f24ee12e67d27b349d8ab47e54825
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Sep 5, 2024
1 parent ea0f4d0 commit 81ae241
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
10 changes: 8 additions & 2 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,10 @@ def create_sharding_infos_by_sharding(
embedding_names=embedding_names,
weight_init_max=config.weight_init_max,
weight_init_min=config.weight_init_min,
num_embeddings_post_pruning=config.num_embeddings_post_pruning,
num_embeddings_post_pruning=(
getattr(config, "num_embeddings_post_pruning", None)
# TODO: Need to check if attribute exists for BC
),
),
param_sharding=parameter_sharding,
param=param,
Expand Down Expand Up @@ -402,7 +405,10 @@ def create_sharding_infos_by_sharding_device_group(
embedding_names=embedding_names,
weight_init_max=config.weight_init_max,
weight_init_min=config.weight_init_min,
num_embeddings_post_pruning=config.num_embeddings_post_pruning,
num_embeddings_post_pruning=(
getattr(config, "num_embeddings_post_pruning", None)
# TODO: Need to check if attribute exists for BC
),
),
param_sharding=parameter_sharding,
param=param,
Expand Down
8 changes: 5 additions & 3 deletions torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,11 @@ def __init__(
(
table.name,
(
table.num_embeddings
if table.num_embeddings_post_pruning is None
else table.num_embeddings_post_pruning
table.num_embeddings_post_pruning
# TODO: Need to check if attribute exists for BC
if getattr(table, "num_embeddings_post_pruning", None)
is not None
else table.num_embeddings
),
table.embedding_dim,
data_type_to_sparse_type(data_type),
Expand Down

0 comments on commit 81ae241

Please sign in to comment.