diff --git a/domainlab/models/interface_vae_xyd.py b/domainlab/models/interface_vae_xyd.py index adf101b54..95794cad7 100644 --- a/domainlab/models/interface_vae_xyd.py +++ b/domainlab/models/interface_vae_xyd.py @@ -27,9 +27,6 @@ def _init_components(self): """ self.add_module("encoder", self.chain_node_builder.build_encoder()) self.add_module("decoder", self.chain_node_builder.build_decoder()) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) def init_p_zx4batch(self, batch_size, device): """ diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index f22df41e1..1a6a2ee04 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -96,30 +96,10 @@ def __init__(self, chain_node_builder, super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y) - # topic to zd follows Gaussian distribution self.add_module("net_p_zd", self.chain_node_builder.construct_cond_prior( - self.topic_dim, self.zd_dim)) - - # override interface - def _init_components(self): - """ - q(z|x) - p(zy) - q_{classif}(zy) - """ - self.add_module("encoder", self.chain_node_builder.build_encoder( - self.device, self.topic_dim)) - self.add_module("decoder", self.chain_node_builder.build_decoder( - self.topic_dim)) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) - self.add_module("net_classif_y", - self.chain_node_builder.construct_classifier( - self.zy_dim, self.dim_y)) - self._net_classifier = self.net_classif_y + self.topic_dim, self.zd_dim) def init_p_topic_batch(self, batch_size, device): """ diff --git a/domainlab/models/model_vae_xyd_classif.py b/domainlab/models/model_vae_xyd_classif.py index 0b9917425..af1c73bdf 100644 --- a/domainlab/models/model_vae_xyd_classif.py +++ b/domainlab/models/model_vae_xyd_classif.py @@ -33,6 +33,8 @@ def multiplier4task_loss(self): def _init_components(self): super()._init_components() + self.add_module("net_p_zy", + self.chain_node_builder.construct_cond_prior(self.dim_y, self.zy_dim)) self.add_module("net_classif_y", self.chain_node_builder.construct_classifier( self.zy_dim, self.dim_y))