Skip to content

Commit

Permalink
Fix correctness of bound optimization (#97)
Browse files Browse the repository at this point in the history
* Rename notebooks and integrate simple solvers for tile-based approaches.

* Add in new change for tilt-bound

* Add gaussian bias test (it works woohoo)

* Fix python version to 3.9

* Use 3.9.13 instead?

* Fix to python 3.10
  • Loading branch information
JamesYang007 authored Nov 1, 2022
1 parent 227ce65 commit 804a1bc
Show file tree
Hide file tree
Showing 12 changed files with 1,388 additions and 168 deletions.
218 changes: 216 additions & 2 deletions confirm/confirm/mini_imprint/bound/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,150 @@ def solve(self, theta_0, v, a):
return self._solve(theta_0, v, a)[2]


def q_holder_bound_fwd(
def _simple_bisection(f, m, M, tol):
def _cond_fun(args):
m, M = args
return (M - m) / M > tol

def _body_fun(args):
m, M = args
x = jnp.linspace(m, M, 4)
y = f(x)
i_star = jnp.argmin(y)
new_min = jnp.where(
i_star <= 1,
m,
x[i_star - 1],
)
new_max = jnp.where(
i_star <= 1,
x[i_star + 1],
M,
)
return (
new_min,
new_max,
)

_init_val = (m, M)
m, M = jax.lax.while_loop(
_cond_fun,
_body_fun,
_init_val,
)
return (M + m) / 2.0


class BaseTileQCPSolver:
def __init__(self, n, m=1, M=1e7, tol=1e-5):
self.n = n
self.min = m
self.max = M
self.tol = tol


class TileForwardQCPSolver(BaseTileQCPSolver):
"""
Solves the following strictly quasi-convex optimization problem:
minimize_q max_{v \in S} L_v(q)
subject to q >= 1
where
L_v(q) = (psi(theta_0, v, q) - log(a)) / q - psi(theta_0, v, 1)
"""

def obj_v(self, theta_0, v, q, loga):
secq = A_secant(
self.n,
theta_0,
v,
q,
theta_0,
)
sec1 = A_secant(
self.n,
theta_0,
v,
1,
theta_0,
)
return secq - loga / q - sec1

def obj(self, theta_0, vs, q, loga):
_obj_each_vmap = jax.vmap(self.obj_v, in_axes=(None, 0, None, None))
return jnp.max(_obj_each_vmap(theta_0, vs, q, loga))

def obj_vmap(self, theta_0, vs, qs, loga):
return jax.vmap(
self.obj,
in_axes=(None, None, 0, None),
)(theta_0, vs, qs, loga)

def solve(self, theta_0, vs, a):
loga = jnp.log(a)
return jax.lax.cond(
loga < -1e10,
lambda: jnp.inf,
lambda: _simple_bisection(
lambda x: self.obj_vmap(theta_0, vs, x, loga),
self.min,
self.max,
self.tol,
),
)


class TileBackwardQCPSolver(BaseTileQCPSolver):
"""
Solves the following strictly quasi-convex optimization problem:
minimize_q max_{v \in S} L_v(q)
subject to q >= 1
where
L_v(q) = (q/(q-1)) * (psi(theta_0, v, q) / q - psi(theta_0, v, 1) - log(a))
"""

def obj_v(self, theta_0, v, q):
secq = A_secant(
self.n,
theta_0,
v,
q,
theta_0,
)
sec1 = A_secant(
self.n,
theta_0,
v,
1,
theta_0,
)
return secq - sec1

def obj(self, theta_0, vs, q, loga):
p = 1.0 / (1.0 - 1.0 / q)
_obj_each_vmap = jax.vmap(self.obj_v, in_axes=(None, 0, None))
return p * (jnp.max(_obj_each_vmap(theta_0, vs, q)) - loga)

def obj_vmap(self, theta_0, vs, qs, loga):
return jax.vmap(
self.obj,
in_axes=(None, None, 0, None),
)(theta_0, vs, qs, loga)

def solve(self, theta_0, vs, a):
loga = jnp.log(a)
return jax.lax.cond(
loga < -1e10,
lambda: jnp.inf,
lambda: _simple_bisection(
lambda x: self.obj_vmap(theta_0, vs, x, loga),
self.min,
self.max,
self.tol,
),
)


def tilt_bound_fwd(
q,
n,
theta_0,
Expand All @@ -773,7 +916,39 @@ def q_holder_bound_fwd(
return f0 ** (1 - 1 / q) * jnp.exp(expo)


def q_holder_bound_bwd(
def tilt_bound_fwd_tile(
q,
n,
theta_0,
vs,
f0,
):
"""
Computes the forward q-Holder bound given by:
f0 * max_{v in vs} exp[L(q) - (A(theta_0 + v) - A(theta_0))]
for fixed f0, n, theta_0, vs,
where L, A are as given in ForwardQCPSolver.
Parameters:
-----------
q: q parameter.
n: scalar Binomial size parameter.
theta_0: d-array pivot point.
vs: (k, d)-array of displacement vectors
denoting the corners of a rectangle.
f0: probability value at theta_0.
"""

def _expo(v):
expo = A_secant(n, theta_0, v, q, theta_0)
expo = expo - A_secant(n, theta_0, v, 1, theta_0)
return expo

max_expo = jnp.max(jax.vmap(_expo, in_axes=(0,))(vs))
return f0 ** (1 - 1 / q) * jnp.exp(max_expo)


def tilt_bound_bwd(
q,
n,
theta_0,
Expand Down Expand Up @@ -808,3 +983,42 @@ def _bound(q):
_bound,
q,
)


def tilt_bound_bwd_tile(
q,
n,
theta_0,
vs,
alpha,
):
"""
Computes the backward q-Holder bound given by:
max_{v in vs} exp(-L(q))
where L(q) is as given in BackwardQCPSolver.
Parameters:
-----------
q: q parameter.
n: scalar Binomial size parameter.
theta_0: d-array pivot point.
vs: (k, d)-array displacement from pivot point.
These represent the corners of the rectangular tile.
alpha: target level.
"""
p = 1 / (1 - 1 / q)

def _expo(v):
slope_diff = A_secant(n, theta_0, v, q, theta_0)
slope_diff = slope_diff - A_secant(n, theta_0, v, 1, theta_0)
return slope_diff

def _bound():
max_expo = jnp.max(jax.vmap(_expo, in_axes=(0,))(vs))
return (alpha * jnp.exp(-max_expo)) ** p

return jax.lax.cond(
q <= 1,
lambda: (alpha >= 1) + 0.0,
_bound,
)
Loading

0 comments on commit 804a1bc

Please sign in to comment.