From 827f90b9e003ac770f98b9eb49b94ebab78b267f Mon Sep 17 00:00:00 2001
From: eknag
Date: Fri, 29 Sep 2023 12:34:43 -0700
Subject: [PATCH] =?UTF-8?q?changed=20to=20proper=20Xavier=20initialization?=
 =?UTF-8?q?,=20existing=20implementation=20was=20=E2=80=A6=20(#1927)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
…resulting in a large negative bias, which was killing all gradients through the following relu.

https://paperswithcode.com/method/xavier-initialization

Pull Request resolved: https://github.com/pytorch/benchmark/pull/1927

Reviewed By: davidberard98

Differential Revision: D49754019

Pulled By: xuzhao9

fbshipit-source-id: 436676afed9bcc0f464cd1b25465444a98a52b5a
---
 torchbenchmark/models/dlrm/dlrm_s_pytorch.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchbenchmark/models/dlrm/dlrm_s_pytorch.py b/torchbenchmark/models/dlrm/dlrm_s_pytorch.py
index 584c7a53b2..59b24045d3 100644
--- a/torchbenchmark/models/dlrm/dlrm_s_pytorch.py
+++ b/torchbenchmark/models/dlrm/dlrm_s_pytorch.py
@@ -149,8 +149,7 @@ def create_mlp(self, ln, sigmoid_layer):
             mean = 0.0 # std_dev = np.sqrt(variance)
             std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
             W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
-            std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
-            bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
+            bt = np.zeros(m).astype(np.float32) # see upstream PR at https://github.com/facebookresearch/dlrm/pull/358
             # approach 1
             LL.weight.data = torch.tensor(W, requires_grad=True)
             LL.bias.data = torch.tensor(bt, requires_grad=True)