diff --git a/easy_rec/python/model/dropoutnet.py b/easy_rec/python/model/dropoutnet.py index 03607a361..b4c30e77c 100644 --- a/easy_rec/python/model/dropoutnet.py +++ b/easy_rec/python/model/dropoutnet.py @@ -21,8 +21,8 @@ def cosine_similarity(user_emb, item_emb): tf.multiply(user_emb, item_emb), axis=1, name='cosine') return user_item_sim -def bernoulli_dropout(x, rate=0.5): - if rate == 0.0: +def bernoulli_dropout(x, rate, training=False): + if rate == 0.0 or not training: return x keep_rate = 1.0 - rate dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype) @@ -90,12 +90,9 @@ def build_predict_graph(self): content_feature = user_content_dnn(self.user_content_feature) user_features.append(content_feature) if self.user_preference_feature is not None: - if self._is_training: - user_prefer_feature = bernoulli_dropout(self.user_preference_feature, - self._model_config.user_dropout_rate) - else: - user_prefer_feature = self.user_preference_feature - + user_prefer_feature = bernoulli_dropout(self.user_preference_feature, + self._model_config.user_dropout_rate, + self._is_training) user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg, 'user_preference', self._is_training) prefer_feature = user_prefer_dnn(user_prefer_feature) @@ -121,12 +118,9 @@ def build_predict_graph(self): content_feature = item_content_dnn(self.item_content_feature) item_features.append(content_feature) if self.item_preference_feature is not None: - if self._is_training: - item_prefer_feature = bernoulli_dropout(self.item_preference_feature, - self._model_config.item_dropout_rate) - else: - item_prefer_feature = self.item_preference_feature - + item_prefer_feature = bernoulli_dropout(self.item_preference_feature, + self._model_config.item_dropout_rate, + self._is_training) item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg, 'item_preference', self._is_training) prefer_feature = item_prefer_dnn(item_prefer_feature)