diff --git a/base/model.py b/base/model.py index 83ee06d..a671e71 100644 --- a/base/model.py +++ b/base/model.py @@ -182,13 +182,13 @@ def __init__(self, d_model, n_head, norm_eps, dropout, block_size): self.norm = RMSNorm(d_model, eps=norm_eps) def forward(self, src): - src_att = self.s_att(self.norm(src)) - src_out = src + self.dropout(src_att) + src = self.norm(src) + src_out = src + self.dropout(self.s_att(src)) - src = self.ffwd(self.norm(src_out)) - src_f = src_out + self.dropout(src) + src = self.norm(src_out) + src_f = src + self.dropout(self.ffwd(src)) - del src_att, src_out, src + del src_out, src return src_f class DecoderNetwork(nn.Module): @@ -201,14 +201,14 @@ def __init__(self, d_model, n_head, norm_eps, dropout, block_size): self.norm = RMSNorm(d_model, eps=norm_eps) def forward(self, src, att): - m_att_out = self.m_att(self.norm(src)) - m_out = src + self.dropout(m_att_out) + m_att_out = self.norm(src) + m_out = src + self.dropout(self.m_att(m_att_out)) - f_out = self.f_att(self.norm(m_out), self.norm(att)) + f_out = self.f_att(m_out, self.norm(att)) f_out = m_out + self.dropout(f_out) - src_f = self.ffwd(self.norm(f_out)) - src_f = f_out + self.dropout(src_f) + src_f = self.norm(f_out) + src_f = f_out + self.dropout(self.ffwd(src_f)) del f_out, m_out, m_att_out, src, att return src_f