From a139c02b498d789ad20f8b91a5aab48d0d027d2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=AB=E8=8B=8F?= Date: Wed, 4 Dec 2024 14:27:07 +0800 Subject: [PATCH] modify dropoutnet in case of batch size mismatch --- easy_rec/python/model/dropoutnet.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/easy_rec/python/model/dropoutnet.py b/easy_rec/python/model/dropoutnet.py index 03607a361..b4c30e77c 100644 --- a/easy_rec/python/model/dropoutnet.py +++ b/easy_rec/python/model/dropoutnet.py @@ -21,8 +21,8 @@ def cosine_similarity(user_emb, item_emb): tf.multiply(user_emb, item_emb), axis=1, name='cosine') return user_item_sim -def bernoulli_dropout(x, rate=0.5): - if rate == 0.0: +def bernoulli_dropout(x, rate, training=False): + if rate == 0.0 or not training: return x keep_rate = 1.0 - rate dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype) @@ -90,12 +90,9 @@ def build_predict_graph(self): content_feature = user_content_dnn(self.user_content_feature) user_features.append(content_feature) if self.user_preference_feature is not None: - if self._is_training: - user_prefer_feature = bernoulli_dropout(self.user_preference_feature, - self._model_config.user_dropout_rate) - else: - user_prefer_feature = self.user_preference_feature - + user_prefer_feature = bernoulli_dropout(self.user_preference_feature, + self._model_config.user_dropout_rate, + self._is_training) user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg, 'user_preference', self._is_training) prefer_feature = user_prefer_dnn(user_prefer_feature) @@ -121,12 +118,9 @@ def build_predict_graph(self): content_feature = item_content_dnn(self.item_content_feature) item_features.append(content_feature) if self.item_preference_feature is not None: - if self._is_training: - item_prefer_feature = bernoulli_dropout(self.item_preference_feature, - self._model_config.item_dropout_rate) - else: - item_prefer_feature = self.item_preference_feature - + item_prefer_feature = bernoulli_dropout(self.item_preference_feature, + self._model_config.item_dropout_rate, + self._is_training) item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg, 'item_preference', self._is_training) prefer_feature = item_prefer_dnn(item_prefer_feature)