Skip to content

Commit

Permalink
Redefine n_factors as the number of uninformed factors (patched)
Browse files Browse the repository at this point in the history
  • Loading branch information
arberqoku committed May 30, 2024
1 parent eb9e6a8 commit cdb4fa8
Showing 1 changed file with 18 additions and 34 deletions.
52 changes: 18 additions & 34 deletions muvi/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
observations: MultiView,
prior_masks: Optional[MultiView] = None,
covariates: Optional[SingleView] = None,
prior_confidence: Optional[Union[float, str]] = "med",
prior_confidence: Optional[Union[float, str]] = "low",
n_factors: Optional[int] = None,
view_names: Optional[list[str]] = None,
likelihoods: Optional[Union[dict[str, str], list[str]]] = None,
Expand All @@ -74,10 +74,9 @@ def __init__(
typical values are 'low' (0.99), 'med' (0.995) and 'high' (0.999),
              by default 'low'
n_factors : int, optional
Number of latent factors,
Number of the uninformed latent factors,
can be omitted when providing prior masks,
or it can be used to introduce additional dense factors
if larger than the informed factors,
or it can be used to introduce additional dense factors,
by default None
view_names : list[str], optional
List of names for each view,
Expand Down Expand Up @@ -471,6 +470,7 @@ def _setup_prior_masks(self, masks, n_factors):

if not informed:
self.n_factors = n_factors
self.n_dense_factors = n_factors
# TODO: duplicate line...see below
self.factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)])
return None, None
Expand All @@ -484,27 +484,12 @@ def _setup_prior_masks(self, masks, n_factors):
informed_views = [vn for vn in self.view_names if vn in masks]

n_prior_factors = masks[informed_views[0]].shape[0]

if n_factors is None:
n_factors = n_prior_factors
n_factors = 0

if n_prior_factors > n_factors:
logger.warning(
f"Prior mask informs more factors ({n_prior_factors}) "
f"than the pre-defined `n_factors` ({n_factors}). "
f"Updating `n_factors` to {n_prior_factors}."
)
n_factors = n_prior_factors

n_dense_factors = 0
if n_prior_factors < n_factors:
logger.warning(
f"Prior mask informs fewer factors ({n_prior_factors}) "
f"than the pre-defined `n_factors` ({n_factors}). "
f"Informing only the first {n_prior_factors} factors, "
"the rest remains uninformed."
)
# extend all prior masks with additional uninformed factors
n_dense_factors = n_factors - n_prior_factors
n_dense_factors = n_factors
n_factors += n_prior_factors

factor_names = None
for vn in self.view_names:
Expand Down Expand Up @@ -555,7 +540,10 @@ def _setup_prior_masks(self, masks, n_factors):
[
view_mask,
np.zeros(
(n_factors, n_features_view - n_features_mask)
(
view_mask.shape[0],
n_features_view - n_features_mask,
)
),
],
axis=1,
Expand Down Expand Up @@ -647,7 +635,8 @@ def _setup_prior_masks(self, masks, n_factors):
if n_dense_factors > 0:
prior_masks = {
vn: np.concatenate(
[vm, np.ones((n_dense_factors, self.n_features[vn])).astype(bool)]
[vm, np.ones((n_dense_factors, self.n_features[vn])).astype(bool)],
axis=0,
)
for vn, vm in masks.items()
}
Expand Down Expand Up @@ -1205,13 +1194,13 @@ def get_covariates(
as_df=as_df,
)

def _setup_model_guide(self, batch_size: int, scale_elbo: bool):
def _setup_model_guide(self, scale_elbo: bool):
"""Setup model and guide.
Parameters
----------
batch_size : int
Batch size when subsampling
scale_elbo : bool, optional
Whether to scale the ELBO across views, by default True
Returns
-------
Expand All @@ -1231,7 +1220,6 @@ def _setup_model_guide(self, batch_size: int, scale_elbo: bool):

self._model = MuVIModel(
self.n_samples,
n_subsamples=batch_size,
n_features=[self.n_features[vn] for vn in self.view_names],
n_factors=self.n_factors,
prior_scales=prior_scales,
Expand Down Expand Up @@ -1408,7 +1396,7 @@ def fit(
if n_particles > 1:
logger.info(f"Using {n_particles} particles in parallel.")
logger.info("Preparing model and guide...")
self._setup_model_guide(batch_size, scale_elbo)
self._setup_model_guide(scale_elbo)
logger.info("Preparing optimizer...")
opt = self._setup_optimizer(batch_size, n_epochs, learning_rate, optimizer)
logger.info("Preparing SVI...")
Expand Down Expand Up @@ -1520,7 +1508,6 @@ class MuVIModel(PyroModule):
def __init__(
self,
n_samples: int,
n_subsamples: int,
n_features: list[int],
n_factors: int,
prior_scales: Optional[list[torch.Tensor]],
Expand All @@ -1539,8 +1526,6 @@ def __init__(
----------
n_samples : int
Number of samples
n_subsamples : int
Number of subsamples (batch size)
n_features : list[int]
Number of features as list for each view
n_factors : int
Expand Down Expand Up @@ -1568,7 +1553,6 @@ def __init__(
"""
super().__init__(name="MuVIModel")
self.n_samples = n_samples
self.n_subsamples = n_subsamples
self.n_features = n_features
self.feature_offsets = [0, *np.cumsum(self.n_features).tolist()]
self.n_views = len(self.n_features)
Expand Down

0 comments on commit cdb4fa8

Please sign in to comment.