Skip to content

Commit

Permalink
Merge pull request #337 from NathanielF/add_dag_fix_bug
Browse files Browse the repository at this point in the history
add Propensity Score DAG as Quasi-Experiment writeup
  • Loading branch information
drbenvincent authored May 7, 2024
2 parents 0ffe348 + 6ecb992 commit 6ced6a9
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 8 deletions.
10 changes: 8 additions & 2 deletions causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,17 @@ def fit(self, X, t, coords):
distributions. We overwrite the base method because the base method assumes
a variable y and we use t to indicate the treatment variable here.
"""
# Ensure random_seed is used in sample_prior_predictive() and
# sample_posterior_predictive() if provided in sample_kwargs.
random_seed = self.sample_kwargs.get("random_seed", None)

self.build_model(X, t, coords)
with self:
self.idata = pm.sample(**self.sample_kwargs)
self.idata.extend(pm.sample_prior_predictive())
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
self.idata.extend(
pm.sample_posterior_predictive(self.idata, progressbar=False)
pm.sample_posterior_predictive(
self.idata, progressbar=False, random_seed=random_seed
)
)
return self.idata
10 changes: 8 additions & 2 deletions causalpy/tests/test_pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ def test_regression_kink_gradient_change():

def test_inverse_prop():
df = cp.load_data("nhefs")
sample_kwargs = {"tune": 100, "draws": 100, "chains": 2, "cores": 2}
sample_kwargs = {
"tune": 100,
"draws": 500,
"chains": 2,
"cores": 2,
"random_seed": 100,
}
result = cp.pymc_experiments.InversePropensityWeighting(
df,
formula="trt ~ 1 + age + race",
Expand Down Expand Up @@ -93,7 +99,7 @@ def test_inverse_prop():
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="overlap")
assert isinstance(ate_list, list)
fig = result.plot_ATE(prop_draws=10, ate_draws=10)
fig = result.plot_ATE(prop_draws=1, ate_draws=10)
assert isinstance(fig, mpl.figure.Figure)
fig = result.plot_balance_ecdf("age")
assert isinstance(fig, mpl.figure.Figure)
58 changes: 58 additions & 0 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
89 changes: 85 additions & 4 deletions docs/source/quasi_dags.ipynb

Large diffs are not rendered by default.

0 comments on commit 6ced6a9

Please sign in to comment.