diff --git a/muvi/core/models.py b/muvi/core/models.py index d328002..a6ab271 100755 --- a/muvi/core/models.py +++ b/muvi/core/models.py @@ -48,7 +48,7 @@ def __init__( observations: MultiView, prior_masks: Optional[MultiView] = None, covariates: Optional[SingleView] = None, - prior_confidence: Optional[Union[float, str]] = "med", + prior_confidence: Optional[Union[float, str]] = "low", n_factors: Optional[int] = None, view_names: Optional[list[str]] = None, likelihoods: Optional[Union[dict[str, str], list[str]]] = None, @@ -74,10 +74,9 @@ def __init__( typical values are 'low' (0.99), 'med' (0.995) and 'high' (0.999), by default 'med' n_factors : int, optional - Number of latent factors, + Number of the uninformed latent factors, can be omitted when providing prior masks, - or it can be used to introduce additional dense factors - if larger than the informed factors, + or it can be used to introduce additional dense factors, by default None view_names : list[str], optional List of names for each view, @@ -471,6 +470,7 @@ def _setup_prior_masks(self, masks, n_factors): if not informed: self.n_factors = n_factors + self.n_dense_factors = n_factors # TODO: duplicate line...see below self.factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)]) return None, None @@ -484,27 +484,12 @@ def _setup_prior_masks(self, masks, n_factors): informed_views = [vn for vn in self.view_names if vn in masks] n_prior_factors = masks[informed_views[0]].shape[0] + if n_factors is None: - n_factors = n_prior_factors + n_factors = 0 - if n_prior_factors > n_factors: - logger.warning( - f"Prior mask informs more factors ({n_prior_factors}) " - f"than the pre-defined `n_factors` ({n_factors}). " - f"Updating `n_factors` to {n_prior_factors}." - ) - n_factors = n_prior_factors - - n_dense_factors = 0 - if n_prior_factors < n_factors: - logger.warning( - f"Prior mask informs fewer factors ({n_prior_factors}) " - f"than the pre-defined `n_factors` ({n_factors}). " - f"Informing only the first {n_prior_factors} factors, " - "the rest remains uninformed." - ) - # extend all prior masks with additional uninformed factors - n_dense_factors = n_factors - n_prior_factors + n_dense_factors = n_factors + n_factors += n_prior_factors factor_names = None for vn in self.view_names: @@ -555,7 +540,10 @@ def _setup_prior_masks(self, masks, n_factors): [ view_mask, np.zeros( - (n_factors, n_features_view - n_features_mask) + ( + view_mask.shape[0], + n_features_view - n_features_mask, + ) ), ], axis=1, @@ -647,7 +635,8 @@ def _setup_prior_masks(self, masks, n_factors): if n_dense_factors > 0: prior_masks = { vn: np.concatenate( - [vm, np.ones((n_dense_factors, self.n_features[vn])).astype(bool)] + [vm, np.ones((n_dense_factors, self.n_features[vn])).astype(bool)], + axis=0, ) for vn, vm in masks.items() } @@ -1205,13 +1194,13 @@ def get_covariates( as_df=as_df, ) - def _setup_model_guide(self, batch_size: int, scale_elbo: bool): + def _setup_model_guide(self, scale_elbo: bool): """Setup model and guide. Parameters ---------- - batch_size : int - Batch size when subsampling + scale_elbo : bool, optional + Whether to scale the ELBO across views, by default True Returns ------- @@ -1231,7 +1220,6 @@ def _setup_model_guide(self, batch_size: int, scale_elbo: bool): self._model = MuVIModel( self.n_samples, - n_subsamples=batch_size, n_features=[self.n_features[vn] for vn in self.view_names], n_factors=self.n_factors, prior_scales=prior_scales, @@ -1408,7 +1396,7 @@ def fit( if n_particles > 1: logger.info(f"Using {n_particles} particles in parallel.") logger.info("Preparing model and guide...") - self._setup_model_guide(batch_size, scale_elbo) + self._setup_model_guide(scale_elbo) logger.info("Preparing optimizer...") opt = self._setup_optimizer(batch_size, n_epochs, learning_rate, optimizer) logger.info("Preparing SVI...") @@ -1520,7 +1508,6 @@ class MuVIModel(PyroModule): def __init__( self, n_samples: int, - n_subsamples: int, n_features: list[int], n_factors: int, prior_scales: Optional[list[torch.Tensor]], @@ -1539,8 +1526,6 @@ def __init__( ---------- n_samples : int Number of samples - n_subsamples : int - Number of subsamples (batch size) n_features : list[int] Number of features as list for each view n_factors : int @@ -1568,7 +1553,6 @@ def __init__( """ super().__init__(name="MuVIModel") self.n_samples = n_samples - self.n_subsamples = n_subsamples self.n_features = n_features self.feature_offsets = [0, *np.cumsum(self.n_features).tolist()] self.n_views = len(self.n_features)