From 883dc7b6886388b1224d5858e0455b58f0fc6507 Mon Sep 17 00:00:00 2001
From: EIFY
Date: Tue, 26 Nov 2024 14:20:16 -0800
Subject: [PATCH] fix pytorch_default_init()

torch.nn.init.trunc_normal_() defaults to truncation at (a, b), not (a * std, b * std).
---
 algorithmic_efficiency/init_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/init_utils.py b/algorithmic_efficiency/init_utils.py
index 66ed041ce..185480cc7 100644
--- a/algorithmic_efficiency/init_utils.py
+++ b/algorithmic_efficiency/init_utils.py
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
   # Perform lecun_normal initialization.
   fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
   std = math.sqrt(1. / fan_in) / .87962566103423978
-  nn.init.trunc_normal_(module.weight, std=std)
+  nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
   if module.bias is not None:
     nn.init.constant_(module.bias, 0.)