diff --git a/confirm/confirm/mini_imprint/bound/binomial.py b/confirm/confirm/mini_imprint/bound/binomial.py index f1061d45..65003152 100644 --- a/confirm/confirm/mini_imprint/bound/binomial.py +++ b/confirm/confirm/mini_imprint/bound/binomial.py @@ -9,6 +9,22 @@ def logistic(t): return jnp.maximum(t, 0) + jnp.log(1 + jnp.exp(-jnp.abs(t))) +def logistic_secant(t, v, q, b): + """ + Numerically stable implementation of the secant of logistic defined by: + (logistic(t + q * v) - logistic(b)) / q + defined for all t, v in R and q > 0. + It is only numerically stable if t, b are not too large in magnitude + and q is sufficiently away from 0. + """ + t_div_q = t / q + ls_1 = jnp.maximum(t_div_q + v, 0) - jnp.maximum(b, 0) / q + ls_2 = jnp.log(1 + jnp.exp(-jnp.abs(t + q * v))) + ls_2 = ls_2 - jnp.log(1 + jnp.exp(-jnp.abs(b))) + ls_2 = ls_2 / q + return ls_1 + ls_2 + + def A(n, t): """ Log-partition function of a Bernoulli family with d-arms @@ -17,6 +33,14 @@ def A(n, t): return n * jnp.sum(logistic(t)) +def A_secant(n, t, v, q, b): + """ + Numerically stable implementation of the secant of A: + (A(t + q * v) - A(b)) / q + """ + return n * jnp.sum(logistic_secant(t, v, q, b)) + + def dA(n, t): """ Gradient of the log-partition function A. @@ -389,7 +413,7 @@ def objective(self, q, theta_0, v, a): v: displacement from pivot point. a: constant shift. """ - return (self.A(theta_0 + q * v) - self.A(theta_0) - jnp.log(a)) / q + return self.A_secant(theta_0, v, q, theta_0) - jnp.log(a) / q # ============================================================ # Members for optimization routine. @@ -398,6 +422,9 @@ def objective(self, q, theta_0, v, a): def A(self, t): return A(self.n, t) + def A_secant(self, t, v, q, b): + return A_secant(self.n, t, v, q, b) + def dA(self, t): return dA(self.n, t) @@ -567,8 +594,8 @@ def objective(self, q, theta_0, v, a): def _eval(q): p = 1 / (1 - 1 / q) - A0 = self.A(theta_0) - slope_diff = (self.A(theta_0 + q * v) - A0) / q - (self.A(theta_0 + v) - A0) + slope_diff = self.A_secant(theta_0, v, q, theta_0) + slope_diff = slope_diff - self.A_secant(theta_0, v, 1, theta_0) return p * (slope_diff - jnp.log(a)) return jax.lax.cond( @@ -585,16 +612,19 @@ def _eval(q): def A(self, t): return A(self.n, t) + def A_secant(self, t, v, q, b): + return A_secant(self.n, t, v, q, b) + def dA(self, t): return dA(self.n, t) def phi_t(self, q, t, theta_0, v, a): + p_inv = 1 - 1 / q A0 = self.A(theta_0) - return ( - self.A(theta_0 + q * v) - - A0 - - q * (self.A(theta_0 + v) - A0 + jnp.log(a)) - - t * (q - 1) + return q * ( + self.A_secant(theta_0, v, q, theta_0) + - (self.A(theta_0 + v) - A0 + jnp.log(a)) + - t * p_inv ) def dphi_t(self, q, t, theta_0, v, a): @@ -738,8 +768,8 @@ def q_holder_bound_fwd( v: d-array displacement vector. f0: probability value at theta_0. """ - A0 = A(n, theta_0) - expo = (A(n, theta_0 + q * v) - A0) / q - (A(n, theta_0 + v) - A0) + expo = A_secant(n, theta_0, v, q, theta_0) + expo = expo - A_secant(n, theta_0, v, 1, theta_0) return f0 ** (1 - 1 / q) * jnp.exp(expo) @@ -767,14 +797,14 @@ def q_holder_bound_bwd( """ def _bound(q): - p = q / (q - 1) - A0 = A(n, theta_0) - slope_diff = (A(n, theta_0 + q * v) - A0) / q - (A(n, theta_0 + v) - A0) + p = 1 / (1 - 1 / q) + slope_diff = A_secant(n, theta_0, v, q, theta_0) + slope_diff = slope_diff - A_secant(n, theta_0, v, 1, theta_0) return (alpha * jnp.exp(-slope_diff)) ** p return jax.lax.cond( q <= 1, - lambda _: float(alpha >= 1), + lambda _: (alpha >= 1) + 0.0, _bound, q, ) diff --git a/docs/tutorial/q-holder-bound.md b/docs/tutorial/q-holder-bound.md index 48c7b403..71bf2143 100644 --- a/docs/tutorial/q-holder-bound.md +++ b/docs/tutorial/q-holder-bound.md @@ -7,7 +7,7 @@ jupyter: format_version: '1.3' jupytext_version: 1.13.8 kernelspec: - display_name: Python 3.10.6 ('confirm') + display_name: Python 3.10.5 ('confirm') language: python name: python3 --- @@ -35,6 +35,8 @@ and requires non-trivial implementation to compute these bounds. import jax import jax.numpy as jnp import numpy as np +import scipy +from functools import partial import matplotlib.pyplot as plt import pyimprint.grid as pygrid import confirm.mini_imprint.grid as grid @@ -159,18 +161,35 @@ Assume that we have access to the Type I Error values at each of the `thetas`. ```python gr = pygrid.make_cartesian_grid_range( size=100, - lower=-np.ones(theta_0.shape[0]), - upper=np.ones(theta_0.shape[0]), + lower=-np.ones(2), + upper=np.ones(2), grid_sim_size=0, # dummy for now ) thetas = gr.thetas().T +subset = thetas[:, 0] >= thetas[:, 1] +thetas = thetas[subset] radii = gr.radii().T -f0s = np.full(thetas.shape[0], 0.025) # dummy values +radii = radii[subset] + +# dummy function to output TIE of a simple test. +def simple_TIE(n, theta, alpha): + prob = scipy.special.expit(theta) + var = np.sum(n * prob * (1 - prob), axis=-1) + mean = n * (prob[:, 1] - prob[:, 0]) / np.sqrt(var) + z_crit = scipy.stats.norm.isf(alpha) + return scipy.stats.norm.sf(z_crit - mean) + +f0s = simple_TIE(n, thetas, 0.025) +``` + +```python +sc = plt.scatter(thetas[:, 0], thetas[:, 1], c=f0s, marker='.') +plt.colorbar(sc) ``` ```python hypercube = grid.hypercube_vertices( - theta_0.shape[0], + thetas.shape[1], ) make_vertices = jax.vmap( lambda r: r * hypercube, @@ -193,3 +212,103 @@ bounds = q_holder_bound_fwd_tile_jvmap( n, thetas, vertices, f0s, ) ``` + +```python +sc = plt.scatter(thetas[:,0], thetas[:,1], c=bounds, marker='.') +plt.colorbar(sc) +``` + +## Tuning Step + + +In the tuning step, we invert the q-Holder bound and optimize a different objective. +The inversion gives us a bound on the TIE at `theta_0` in order to achieve level alpha in a tile. +Intuitively, we would like this inverted bound to be as large as possible (less conservative). +Hence, the `q` parameter allows us to maximize the bound for any given `v`. + +```python +bwd_solver = binomial.BackwardQCPSolver(n=n) +``` + +```python +def _sanity_check_bwd_solver(): + q_grid = np.linspace(1+1e-6, 4, 1000) + bound_vmap = jax.vmap(binomial.q_holder_bound_bwd, in_axes=(0, None, None, None, None)) + bounds = bound_vmap(q_grid, n, theta_0, v, f0) + q_opt = bwd_solver.solve(theta_0, v, f0) + opt_bound = binomial.q_holder_bound_bwd(q_opt, n, theta_0, v, f0) + plt.plot(q_grid, bounds, '--') + plt.plot(q_opt, opt_bound, 'r.') +_sanity_check_bwd_solver() +``` + +Similar to the validation step, we will extend the optimization routine for the whole tile. + +```python +def q_holder_bound_bwd_tile(n, theta_0, vs, f0, bwd_solver): + # vectorize bwd_solver over v + bwd_solver_vmap_v = jax.vmap( + bwd_solver.solve, + in_axes=(None, 0, None), + ) + q_opts = bwd_solver_vmap_v(theta_0, vs, f0) + + # compute bounds over q_opt, v + bound_vmap = jax.vmap( + binomial.q_holder_bound_bwd, + in_axes=(0, None, None, 0, None), + ) + bounds = bound_vmap( + q_opts, n, theta_0, vs, f0, + ) + + # find the minimax of the bounds at corners + i_max = jnp.argmin(bounds) + return vs[i_max], q_opts[i_max], bounds[i_max] +``` + +```python +def invert_bound(alpha, theta_0, vertices, n): + v = vertices - theta_0 + q_opt = jax.vmap(bwd_solver.solve, in_axes=(None, 0, None))( + theta_0, v, alpha + ) + return jnp.min(jax.vmap(binomial.q_holder_bound_bwd, in_axes=(0, None, None, 0, None))( + q_opt, n, theta_0, v, alpha + )) +``` + +```python +q_holder_bound_bwd_tile( + n, theta_0, vs, f0, bwd_solver +) +``` + +### Main Workflow + + +The main workflow is similar to that of validation step. +For each tile (and its corresponding simulation point), +we would like to optimize for $q$ based on the worse corner +of the backward bound. + +The following is an example of a workflow. +We continue with the setup as in validation step. + +```python +q_holder_bound_bwd_tile_jvmap = jax.jit(jax.vmap( + lambda n, t, v, f0: q_holder_bound_bwd_tile(n, t, v, f0, bwd_solver)[-1], + in_axes=(None, 0, 0, 0), +)) +``` + +```python +bounds_bwd = q_holder_bound_bwd_tile_jvmap( + n, thetas, vertices, f0s, +) +``` + +```python +sc = plt.scatter(thetas[:,0], thetas[:,1], c=bounds_bwd, marker='.') +plt.colorbar(sc) +``` diff --git a/imprint/.vscode/build.sh b/imprint/.vscode/build.sh index 70f3c4af..315cceab 100644 --- a/imprint/.vscode/build.sh +++ b/imprint/.vscode/build.sh @@ -1,5 +1,6 @@ #!/bin/bash eval "$(conda shell.bash hook)" conda activate imprint -bazel build -c opt --config gcc //python:pyimprint/core.so -cp -f ./bazel-bin/python/pyimprint/core.so python/pyimprint/ \ No newline at end of file +bazel build -c opt //python:pyimprint/core.so +rm -f python/pyimprint/core.so +cp ./bazel-bin/python/pyimprint/core.so python/pyimprint/ \ No newline at end of file diff --git a/research/q-holder-bound/qcp-solver.ipynb b/research/q-holder-bound/qcp-solver.ipynb index f2405bf9..8f85bb58 100644 --- a/research/q-holder-bound/qcp-solver.ipynb +++ b/research/q-holder-bound/qcp-solver.ipynb @@ -559,7 +559,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.6 ('confirm')", + "display_name": "Python 3.10.5 ('confirm')", "language": "python", "name": "python3" }, @@ -573,12 +573,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "5d574717a19d12573763700bcd6833eaae2108879723021a1c549979ef70be90" + "hash": "d8e1ca1b3fede25e3995e2b26ea544fa1b75b9a17984e6284a43c1dc286640dd" } } }, diff --git a/research/q-holder-bound/qcp-solver.md b/research/q-holder-bound/qcp-solver.md index c9437a73..3d096b9e 100644 --- a/research/q-holder-bound/qcp-solver.md +++ b/research/q-holder-bound/qcp-solver.md @@ -7,7 +7,7 @@ jupyter: format_version: '1.3' jupytext_version: 1.13.8 kernelspec: - display_name: Python 3.10.6 ('confirm') + display_name: Python 3.10.5 ('confirm') language: python name: python3 ---