From 7fe41a54f8b1d626fde114ecd4df8b1b079fb562 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 13 Apr 2023 16:52:52 +0200 Subject: [PATCH 1/3] use parent init_component --- domainlab/models/model_hduva.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index ab735e3c4..7324a96bb 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -52,29 +52,12 @@ def __init__(self, chain_node_builder, super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr) - + # beta_d:y:x:t initialized via @store_args # topic to zd follows Gaussian distribution self.add_module("net_p_zd", self.chain_node_builder.construct_cond_prior( self.topic_dim, self.zd_dim)) - def _init_components(self): - """ - q(z|x) - p(zy) - q_{classif}(zy) - """ - self.add_module("encoder", self.chain_node_builder.build_encoder( - self.device, self.topic_dim)) - self.add_module("decoder", self.chain_node_builder.build_decoder( - self.topic_dim)) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) - self.add_module("net_classif_y", - self.chain_node_builder.construct_classifier( - self.zy_dim, self.dim_y)) - def init_p_topic_batch(self, batch_size, device): """ flat prior From 27ba0cae29b1de6480df74630864fcbe9f5b253d Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 20 Apr 2023 14:52:31 +0200 Subject: [PATCH 2/3] remove y in interface vae xyd --- domainlab/models/interface_vae_xyd.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/domainlab/models/interface_vae_xyd.py b/domainlab/models/interface_vae_xyd.py index adf101b54..95794cad7 100644 --- a/domainlab/models/interface_vae_xyd.py +++ b/domainlab/models/interface_vae_xyd.py @@ -27,9 +27,6 @@ def _init_components(self): """ self.add_module("encoder", self.chain_node_builder.build_encoder()) self.add_module("decoder", self.chain_node_builder.build_decoder()) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) def init_p_zx4batch(self, batch_size, device): """ From 445321f88218a9b0d2a38d059176d2d17c6ff3fc Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 20 Apr 2023 15:44:26 +0200 Subject: [PATCH 3/3] add net_p_zy --- domainlab/models/model_vae_xyd_classif.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/domainlab/models/model_vae_xyd_classif.py b/domainlab/models/model_vae_xyd_classif.py index 577e9e718..9a3dc177f 100644 --- a/domainlab/models/model_vae_xyd_classif.py +++ b/domainlab/models/model_vae_xyd_classif.py @@ -30,6 +30,8 @@ def cal_logit_y(self, tensor_x): def _init_components(self): super()._init_components() + self.add_module("net_p_zy", + self.chain_node_builder.construct_cond_prior(self.dim_y, self.zy_dim)) self.add_module("net_classif_y", self.chain_node_builder.construct_classifier( self.zy_dim, self.dim_y))