Skip to content

Commit

Permalink
v0.1.2 Fix categorical initialization bug (#15)
Browse files Browse the repository at this point in the history
* fix categorical init bug

* update gitignore

* update README
  • Loading branch information
jan-engelmann authored Jul 15, 2024
1 parent 74b0940 commit 3f1abac
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,4 @@ cython_debug/
.DS_Store
data
*.out
assets
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# MixMIL
Code for the paper: [Mixed Models with Multiple Instance Learning](https://arxiv.org/abs/2311.02455)

Accepted at AISTATS 24 as an oral presentation.
Accepted at AISTATS 24 as an oral presentation & [Outstanding Student Paper Highlight](https://aistats.org/aistats2024/awards.html).

Please raise an issue for questions and bug-reports.
## Installation
Expand Down
2 changes: 1 addition & 1 deletion mixmil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mixmil.model import MixMIL

__all__ = ["MixMIL", "utils", "likelihood", "posterior", "data", "simulation", "paths"]
__version__ = "0.1.1"
__version__ = "0.1.2"
4 changes: 2 additions & 2 deletions mixmil/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_lr_init_params(X, Y, b, Fiv):
penalty="l2",
multi_class="multinomial",
solver="lbfgs",
n_jobs=8,
n_jobs=1,
verbose=0,
random_state=42,
max_iter=1000,
Expand All @@ -76,7 +76,7 @@ def get_lr_init_params(X, Y, b, Fiv):
alpha = Fiv.dot(np.ones((Fiv.shape[1], 1))).dot(alpha) - b.dot(mu_beta)

# init prior
var_z = (mu_beta**2 + sd_beta**2).mean().reshape(1, 1)
var_z = (mu_beta**2 + sd_beta**2).mean(axis=0, keepdims=True)

return [torch.Tensor(el) for el in (mu_beta, sd_beta, var_z, alpha)]

Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ classifiers = [
]

[project.urls]
Homepage = "https://github.com/AIH-SGML/MixMIL"
"Bug Tracker" = "https://github.com/AIH-SGML/MixMIL/issues"
Discussions = "https://github.com/AIH-SGML/MixMIL/discussions"
Homepage = "https://github.com/AIH-SGML/mixmil"
"Bug Tracker" = "https://github.com/AIH-SGML/mixmil/issues"
Discussions = "https://github.com/AIH-SGML/mixmil/discussions"

[project.optional-dependencies]
experiments = ["anndata>=0.8.0", "jupyterlab>=3.0.0"]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def mock_data_categorical():
def test_init_with_mean_model_binomial(mock_data_binomial):
Xs, F, Y = mock_data_binomial
model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=2)
model.train(Xs, F, Y, n_epochs=3)
assert isinstance(model, MixMIL)
assert model.likelihood_name == "binomial"
assert model.n_trials == 2
Expand All @@ -35,6 +36,7 @@ def test_init_with_mean_model_binomial(mock_data_binomial):
def test_init_with_mean_model_categorical(mock_data_categorical):
Xs, F, Y = mock_data_categorical
model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="categorical")
model.train(Xs, F, Y, n_epochs=3)
assert isinstance(model, MixMIL)
assert model.likelihood_name == "categorical"
assert model.n_trials is None
Expand Down

0 comments on commit 3f1abac

Please sign in to comment.