diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index e624729db..31331a4e8 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -287,7 +287,10 @@ def create_sharding_infos_by_sharding( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, - num_embeddings_post_pruning=config.num_embeddings_post_pruning, + num_embeddings_post_pruning=( + getattr(config, "num_embeddings_post_pruning", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, @@ -402,7 +405,10 @@ def create_sharding_infos_by_sharding_device_group( embedding_names=embedding_names, weight_init_max=config.weight_init_max, weight_init_min=config.weight_init_min, - num_embeddings_post_pruning=config.num_embeddings_post_pruning, + num_embeddings_post_pruning=( + getattr(config, "num_embeddings_post_pruning", None) + # TODO: Need to check if attribute exists for BC + ), ), param_sharding=parameter_sharding, param=param, diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 24a118671..91fa80581 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -381,9 +381,11 @@ def __init__( ( table.name, ( - table.num_embeddings - if table.num_embeddings_post_pruning is None - else table.num_embeddings_post_pruning + table.num_embeddings_post_pruning + # TODO: Need to check if attribute exists for BC + if getattr(table, "num_embeddings_post_pruning", None) + is not None + else table.num_embeddings ), table.embedding_dim, data_type_to_sparse_type(data_type),