diff --git a/makani/networks/afnonet_v2.py b/makani/networks/afnonet_v2.py index 3b22627..ccff5e3 100644 --- a/makani/networks/afnonet_v2.py +++ b/makani/networks/afnonet_v2.py @@ -22,7 +22,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from apex.normalization import FusedLayerNorm import torch.fft from torch.nn.modules.container import Sequential from torch.utils.checkpoint import checkpoint_sequential @@ -282,7 +281,7 @@ def _init_weights(self, m): # nn.init.normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm) or isinstance(m, nn.InstanceNorm3d): + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.InstanceNorm3d): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0)