Skip to content

Commit

Permalink
Improve EH-Bound numerical stability (#88)
Browse files Browse the repository at this point in the history
Add more numerically stable version of bound.
  • Loading branch information
JamesYang007 authored Oct 20, 2022
1 parent c3725e2 commit 227ce65
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 25 deletions.
58 changes: 44 additions & 14 deletions confirm/confirm/mini_imprint/bound/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@ def logistic(t):
return jnp.maximum(t, 0) + jnp.log(1 + jnp.exp(-jnp.abs(t)))


def logistic_secant(t, v, q, b):
"""
Numerically stable implementation of the secant of logistic defined by:
(logistic(t + q * v) - logistic(b)) / q
defined for all t, v in R and q > 0.
It is only numerically stable if t, b are not too large in magnitude
and q is sufficiently away from 0.
"""
t_div_q = t / q
ls_1 = jnp.maximum(t_div_q + v, 0) - jnp.maximum(b, 0) / q
ls_2 = jnp.log(1 + jnp.exp(-jnp.abs(t + q * v)))
ls_2 = ls_2 - jnp.log(1 + jnp.exp(-jnp.abs(b)))
ls_2 = ls_2 / q
return ls_1 + ls_2


def A(n, t):
"""
Log-partition function of a Bernoulli family with d-arms
Expand All @@ -17,6 +33,14 @@ def A(n, t):
return n * jnp.sum(logistic(t))


def A_secant(n, t, v, q, b):
"""
Numerically stable implementation of the secant of A:
(A(t + q * v) - A(b)) / q
"""
return n * jnp.sum(logistic_secant(t, v, q, b))


def dA(n, t):
"""
Gradient of the log-partition function A.
Expand Down Expand Up @@ -389,7 +413,7 @@ def objective(self, q, theta_0, v, a):
v: displacement from pivot point.
a: constant shift.
"""
return (self.A(theta_0 + q * v) - self.A(theta_0) - jnp.log(a)) / q
return self.A_secant(theta_0, v, q, theta_0) - jnp.log(a) / q

# ============================================================
# Members for optimization routine.
Expand All @@ -398,6 +422,9 @@ def objective(self, q, theta_0, v, a):
def A(self, t):
return A(self.n, t)

def A_secant(self, t, v, q, b):
return A_secant(self.n, t, v, q, b)

def dA(self, t):
return dA(self.n, t)

Expand Down Expand Up @@ -567,8 +594,8 @@ def objective(self, q, theta_0, v, a):

def _eval(q):
p = 1 / (1 - 1 / q)
A0 = self.A(theta_0)
slope_diff = (self.A(theta_0 + q * v) - A0) / q - (self.A(theta_0 + v) - A0)
slope_diff = self.A_secant(theta_0, v, q, theta_0)
slope_diff = slope_diff - self.A_secant(theta_0, v, 1, theta_0)
return p * (slope_diff - jnp.log(a))

return jax.lax.cond(
Expand All @@ -585,16 +612,19 @@ def _eval(q):
def A(self, t):
return A(self.n, t)

def A_secant(self, t, v, q, b):
return A_secant(self.n, t, v, q, b)

def dA(self, t):
return dA(self.n, t)

def phi_t(self, q, t, theta_0, v, a):
p_inv = 1 - 1 / q
A0 = self.A(theta_0)
return (
self.A(theta_0 + q * v)
- A0
- q * (self.A(theta_0 + v) - A0 + jnp.log(a))
- t * (q - 1)
return q * (
self.A_secant(theta_0, v, q, theta_0)
- (self.A(theta_0 + v) - A0 + jnp.log(a))
- t * p_inv
)

def dphi_t(self, q, t, theta_0, v, a):
Expand Down Expand Up @@ -738,8 +768,8 @@ def q_holder_bound_fwd(
v: d-array displacement vector.
f0: probability value at theta_0.
"""
A0 = A(n, theta_0)
expo = (A(n, theta_0 + q * v) - A0) / q - (A(n, theta_0 + v) - A0)
expo = A_secant(n, theta_0, v, q, theta_0)
expo = expo - A_secant(n, theta_0, v, 1, theta_0)
return f0 ** (1 - 1 / q) * jnp.exp(expo)


Expand Down Expand Up @@ -767,14 +797,14 @@ def q_holder_bound_bwd(
"""

def _bound(q):
p = q / (q - 1)
A0 = A(n, theta_0)
slope_diff = (A(n, theta_0 + q * v) - A0) / q - (A(n, theta_0 + v) - A0)
p = 1 / (1 - 1 / q)
slope_diff = A_secant(n, theta_0, v, q, theta_0)
slope_diff = slope_diff - A_secant(n, theta_0, v, 1, theta_0)
return (alpha * jnp.exp(-slope_diff)) ** p

return jax.lax.cond(
q <= 1,
lambda _: float(alpha >= 1),
lambda _: (alpha >= 1) + 0.0,
_bound,
q,
)
129 changes: 124 additions & 5 deletions docs/tutorial/q-holder-bound.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jupyter:
format_version: '1.3'
jupytext_version: 1.13.8
kernelspec:
display_name: Python 3.10.6 ('confirm')
display_name: Python 3.10.5 ('confirm')
language: python
name: python3
---
Expand Down Expand Up @@ -35,6 +35,8 @@ and requires non-trivial implementation to compute these bounds.
import jax
import jax.numpy as jnp
import numpy as np
import scipy
from functools import partial
import matplotlib.pyplot as plt
import pyimprint.grid as pygrid
import confirm.mini_imprint.grid as grid
Expand Down Expand Up @@ -159,18 +161,35 @@ Assume that we have access to the Type I Error values at each of the `thetas`.
```python
gr = pygrid.make_cartesian_grid_range(
size=100,
lower=-np.ones(theta_0.shape[0]),
upper=np.ones(theta_0.shape[0]),
lower=-np.ones(2),
upper=np.ones(2),
grid_sim_size=0, # dummy for now
)
thetas = gr.thetas().T
subset = thetas[:, 0] >= thetas[:, 1]
thetas = thetas[subset]
radii = gr.radii().T
f0s = np.full(thetas.shape[0], 0.025) # dummy values
radii = radii[subset]

# dummy function to output TIE of a simple test.
def simple_TIE(n, theta, alpha):
prob = scipy.special.expit(theta)
var = np.sum(n * prob * (1 - prob), axis=-1)
mean = n * (prob[:, 1] - prob[:, 0]) / np.sqrt(var)
z_crit = scipy.stats.norm.isf(alpha)
return scipy.stats.norm.sf(z_crit - mean)

f0s = simple_TIE(n, thetas, 0.025)
```

```python
sc = plt.scatter(thetas[:, 0], thetas[:, 1], c=f0s, marker='.')
plt.colorbar(sc)
```

```python
hypercube = grid.hypercube_vertices(
theta_0.shape[0],
thetas.shape[1],
)
make_vertices = jax.vmap(
lambda r: r * hypercube,
Expand All @@ -193,3 +212,103 @@ bounds = q_holder_bound_fwd_tile_jvmap(
n, thetas, vertices, f0s,
)
```

```python
sc = plt.scatter(thetas[:,0], thetas[:,1], c=bounds, marker='.')
plt.colorbar(sc)
```

## Tuning Step


In the tuning step, we invert the q-Holder bound and optimize a different objective.
The inversion gives us a bound on the TIE at `theta_0` in order to achieve level alpha in a tile.
Intuitively, we would like this inverted bound to be as large as possible (less conservative).
Hence, the `q` parameter allows us to maximize the bound for any given `v`.

```python
bwd_solver = binomial.BackwardQCPSolver(n=n)
```

```python
def _sanity_check_bwd_solver():
q_grid = np.linspace(1+1e-6, 4, 1000)
bound_vmap = jax.vmap(binomial.q_holder_bound_bwd, in_axes=(0, None, None, None, None))
bounds = bound_vmap(q_grid, n, theta_0, v, f0)
q_opt = bwd_solver.solve(theta_0, v, f0)
opt_bound = binomial.q_holder_bound_bwd(q_opt, n, theta_0, v, f0)
plt.plot(q_grid, bounds, '--')
plt.plot(q_opt, opt_bound, 'r.')
_sanity_check_bwd_solver()
```

Similar to the validation step, we will extend the optimization routine for the whole tile.

```python
def q_holder_bound_bwd_tile(n, theta_0, vs, f0, bwd_solver):
# vectorize bwd_solver over v
bwd_solver_vmap_v = jax.vmap(
bwd_solver.solve,
in_axes=(None, 0, None),
)
q_opts = bwd_solver_vmap_v(theta_0, vs, f0)

# compute bounds over q_opt, v
bound_vmap = jax.vmap(
binomial.q_holder_bound_bwd,
in_axes=(0, None, None, 0, None),
)
bounds = bound_vmap(
q_opts, n, theta_0, vs, f0,
)

# find the minimax of the bounds at corners
i_max = jnp.argmin(bounds)
return vs[i_max], q_opts[i_max], bounds[i_max]
```

```python
def invert_bound(alpha, theta_0, vertices, n):
v = vertices - theta_0
q_opt = jax.vmap(bwd_solver.solve, in_axes=(None, 0, None))(
theta_0, v, alpha
)
return jnp.min(jax.vmap(binomial.q_holder_bound_bwd, in_axes=(0, None, None, 0, None))(
q_opt, n, theta_0, v, alpha
))
```

```python
q_holder_bound_bwd_tile(
n, theta_0, vs, f0, bwd_solver
)
```

### Main Workflow


The main workflow is similar to that of validation step.
For each tile (and its corresponding simulation point),
we would like to optimize for $q$ based on the worse corner
of the backward bound.

The following is an example of a workflow.
We continue with the setup as in validation step.

```python
q_holder_bound_bwd_tile_jvmap = jax.jit(jax.vmap(
lambda n, t, v, f0: q_holder_bound_bwd_tile(n, t, v, f0, bwd_solver)[-1],
in_axes=(None, 0, 0, 0),
))
```

```python
bounds_bwd = q_holder_bound_bwd_tile_jvmap(
n, thetas, vertices, f0s,
)
```

```python
sc = plt.scatter(thetas[:,0], thetas[:,1], c=bounds_bwd, marker='.')
plt.colorbar(sc)
```
5 changes: 3 additions & 2 deletions imprint/.vscode/build.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/bin/bash
eval "$(conda shell.bash hook)"
conda activate imprint
bazel build -c opt --config gcc //python:pyimprint/core.so
cp -f ./bazel-bin/python/pyimprint/core.so python/pyimprint/
bazel build -c opt //python:pyimprint/core.so
rm -f python/pyimprint/core.so
cp ./bazel-bin/python/pyimprint/core.so python/pyimprint/
6 changes: 3 additions & 3 deletions research/q-holder-bound/qcp-solver.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 ('confirm')",
"display_name": "Python 3.10.5 ('confirm')",
"language": "python",
"name": "python3"
},
Expand All @@ -573,12 +573,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "5d574717a19d12573763700bcd6833eaae2108879723021a1c549979ef70be90"
"hash": "d8e1ca1b3fede25e3995e2b26ea544fa1b75b9a17984e6284a43c1dc286640dd"
}
}
},
Expand Down
2 changes: 1 addition & 1 deletion research/q-holder-bound/qcp-solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jupyter:
format_version: '1.3'
jupytext_version: 1.13.8
kernelspec:
display_name: Python 3.10.6 ('confirm')
display_name: Python 3.10.5 ('confirm')
language: python
name: python3
---
Expand Down

0 comments on commit 227ce65

Please sign in to comment.