Skip to content

Commit

Permalink
Redefine n_factors as the number of uninformed factors
Browse files Browse the repository at this point in the history
  • Loading branch information
arberqoku committed May 30, 2024
1 parent f87c633 commit eb9e6a8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 24 deletions.
12 changes: 10 additions & 2 deletions examples/1_basic_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Using matplotlib backend: <object object at 0x7f4870dffb30>\n"
"Using matplotlib backend: <object object at 0x1767f3b20>\n"
]
}
],
Expand Down Expand Up @@ -79,7 +79,15 @@
"execution_count": 5,
"id": "db044f07",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Torch not compiled with CUDA enabled\n"
]
}
],
"source": [
"device = \"cpu\"\n",
"try:\n",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_alternative_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ def test_from_adata(pandas_input):
mask = pandas_input["masks"][view_idx]
cov = pandas_input["covariates"]

n_factors = mask.shape[0]

adata = ad.AnnData(obs)
adata.obsm["X_np"] = obs.iloc[:, : (obs.shape[1] // 2)].values
adata.obsm["X_pd"] = obs.iloc[:, : (obs.shape[1] // 2)]
Expand All @@ -41,9 +39,11 @@ def test_from_adata(pandas_input):
if obs_key_val == "":
obs_key_val = None

n_factors = 0
prior_mask_key_val = prior_mask_key
if prior_mask_key_val == "":
prior_mask_key_val = None
n_factors = mask.shape[0]

covariate_key_val = covariate_key
if covariate_key_val == "":
Expand All @@ -69,8 +69,6 @@ def test_from_adata(pandas_input):


def test_from_mdata(pandas_input):
n_factors = pandas_input["masks"][0].shape[0]

adata_dict = {}

for view_idx, view_obs in enumerate(pandas_input["observations"]):
Expand Down Expand Up @@ -101,9 +99,11 @@ def test_from_mdata(pandas_input):
if obs_key_val == "":
obs_key_val = None

n_factors = 0
prior_mask_key_val = prior_mask_key
if prior_mask_key_val == "":
prior_mask_key_val = None
n_factors = pandas_input["masks"][0].shape[0]

covariate_key_val = covariate_key
if covariate_key_val == "":
Expand Down
23 changes: 5 additions & 18 deletions tests/test_prior.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,23 @@
from muvi import MuVI


def test_fewer_n_factors(pandas_input):
model = MuVI(
pandas_input["observations"],
pandas_input["masks"],
pandas_input["covariates"],
n_factors=pandas_input["n_factors"] - 2,
view_names=pandas_input["view_names"],
device="cpu",
)

assert model.n_factors == pandas_input["n_factors"]


def test_more_n_factors(pandas_input):
n_dense = 2
n_factors = 2

model = MuVI(
pandas_input["observations"],
pandas_input["masks"],
pandas_input["covariates"],
n_factors=pandas_input["n_factors"] + n_dense,
n_factors=n_factors,
view_names=pandas_input["view_names"],
device="cpu",
)

assert model.n_factors == pandas_input["n_factors"] + n_dense
assert model.n_factors == pandas_input["n_factors"] + n_factors
assert (
model.factor_names[-n_dense:] == [f"dense_{k}" for k in range(n_dense)]
model.factor_names[-n_factors:] == [f"dense_{k}" for k in range(n_factors)]
).all()

for prior_mask in model.get_prior_masks().values():
assert prior_mask.shape[0] == model.n_factors
assert prior_mask[-n_dense:, :].all()
assert prior_mask[-n_factors:, :].all()

0 comments on commit eb9e6a8

Please sign in to comment.