diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 0ae295a6c0..c06e5aa51b 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -88,7 +88,7 @@ def permute_2D_sparse_data_meta( if permuted_lengths_sum is not None: permuted_indices_size = permuted_lengths_sum else: - ctx = torch._custom_op.impl.get_ctx() + ctx = torch.library.get_ctx() permuted_indices_size = ctx.new_dynamic_size() # pyre-fixme permuted_indices = indices.new_empty(permuted_indices_size) @@ -114,7 +114,7 @@ def permute_1D_sparse_data_meta( if permuted_lengths_sum is not None: permuted_indices_size = permuted_lengths_sum else: - ctx = torch._custom_op.impl.get_ctx() + ctx = torch.library.get_ctx() permuted_indices_size = ctx.new_dynamic_size() # pyre-fixme permuted_indices = indices.new_empty(permuted_indices_size)