Skip to content

Commit

Permalink
fixed the data overfitting issue
Browse files Browse the repository at this point in the history
  • Loading branch information
shivendrra committed Mar 29, 2024
1 parent 9662211 commit 0bad9f6
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
self.norm = RMSNorm(d_model, eps=norm_eps)

def forward(self, src):
src_att = self.s_att(self.norm(src))
src_out = src + self.dropout(src_att)
src = self.norm(src)
src_out = src + self.dropout(self.s_att(src))

src = self.ffwd(self.norm(src_out))
src_f = src_out + self.dropout(src)
src = self.norm(src_out)
src_f = src + self.dropout(self.ffwd(src))

del src_att, src_out, src
del src_out, src
return src_f

class DecoderNetwork(nn.Module):
Expand All @@ -201,14 +201,14 @@ def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
self.norm = RMSNorm(d_model, eps=norm_eps)

def forward(self, src, att):
m_att_out = self.m_att(self.norm(src))
m_out = src + self.dropout(m_att_out)
m_att_out = self.norm(src)
m_out = src + self.dropout(self.m_att(m_att_out))

f_out = self.f_att(self.norm(m_out), self.norm(att))
f_out = self.f_att(m_out, self.norm(att))
f_out = m_out + self.dropout(f_out)

src_f = self.ffwd(self.norm(f_out))
src_f = f_out + self.dropout(src_f)
src_f = self.norm(f_out)
src_f = f_out + self.dropout(self.ffwd(src_f))

del f_out, m_out, m_att_out, src, att
return src_f
Expand Down

0 comments on commit 0bad9f6

Please sign in to comment.