Skip to content

Commit

Permalink
modify dropoutnet in case of batch size mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Dec 4, 2024
1 parent 501139d commit a139c02
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions easy_rec/python/model/dropoutnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def cosine_similarity(user_emb, item_emb):
tf.multiply(user_emb, item_emb), axis=1, name='cosine')
return user_item_sim

def bernoulli_dropout(x, rate=0.5):
if rate == 0.0:
def bernoulli_dropout(x, rate, training=False):
if rate == 0.0 or not training:
return x
keep_rate = 1.0 - rate
dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype)
Expand Down Expand Up @@ -90,12 +90,9 @@ def build_predict_graph(self):
content_feature = user_content_dnn(self.user_content_feature)
user_features.append(content_feature)
if self.user_preference_feature is not None:
if self._is_training:
user_prefer_feature = bernoulli_dropout(self.user_preference_feature,
self._model_config.user_dropout_rate)
else:
user_prefer_feature = self.user_preference_feature

user_prefer_feature = bernoulli_dropout(self.user_preference_feature,
self._model_config.user_dropout_rate,
self._is_training)
user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
'user_preference', self._is_training)
prefer_feature = user_prefer_dnn(user_prefer_feature)
Expand All @@ -121,12 +118,9 @@ def build_predict_graph(self):
content_feature = item_content_dnn(self.item_content_feature)
item_features.append(content_feature)
if self.item_preference_feature is not None:
if self._is_training:
item_prefer_feature = bernoulli_dropout(self.item_preference_feature,
self._model_config.item_dropout_rate)
else:
item_prefer_feature = self.item_preference_feature

item_prefer_feature = bernoulli_dropout(self.item_preference_feature,
self._model_config.item_dropout_rate,
self._is_training)
item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
'item_preference', self._is_training)
prefer_feature = item_prefer_dnn(item_prefer_feature)
Expand Down

0 comments on commit a139c02

Please sign in to comment.