Skip to content

Commit

Permalink
Update joint_logprob usage for newer AePPL
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jan 3, 2023
1 parent 96ea86d commit b8b9641
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Example
def logprob_fn(y):
return joint_logprob({Y_rv: y})
return joint_logprob(realized={Y_rv: y})[0]
# Build the transition kernel
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.18.1
- scipy>=1.4.0
- aesara>=2.8.3
- aeppl>=0.0.38
- aeppl>=0.0.40
# Intel BLAS
- mkl
- mkl-service
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"numpy>=1.18.1",
"scipy>=1.4.0",
"aesara>=2.8.3",
"aeppl>=0.0.38",
"aeppl>=0.0.40",
],
tests_require=["pytest"],
long_description=open("README.md").read() if exists("README.md") else "",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_warmup_scalar():
Y_rv = srng.normal(1, 2)

def logprob_fn(y: TensorVariable):
logprob = joint_logprob({Y_rv: y})
logprob, _ = joint_logprob(realized={Y_rv: y})
return logprob

y_vv = Y_rv.clone()
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_warmup_vector():
Y_rv = srng.multivariate_normal(loc, cov)

def logprob_fn(y: TensorVariable):
logprob = joint_logprob({Y_rv: y})
logprob, _ = joint_logprob(realized={Y_rv: y})
return logprob

y_vv = Y_rv.clone()
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_univariate_hmc(step_size, diverges):
Y_rv = srng.normal(1, 2)

def logprob_fn(y):
logprob = joint_logprob({Y_rv: y})
logprob, _ = joint_logprob(realized={Y_rv: y})
return logprob

kernel = hmc.new_kernel(srng, logprob_fn)
Expand Down Expand Up @@ -162,7 +162,7 @@ def multivariate_normal_model(srng):
Y_rv = srng.multivariate_normal(loc_tt, cov_tt)

def logprob_fn(y):
return joint_logprob({Y_rv: y})
return joint_logprob(realized={Y_rv: y})[0]

return (loc, scale, rho), Y_rv, logprob_fn

Expand Down

0 comments on commit b8b9641

Please sign in to comment.