Adagrid, Tuning, Lewis, Lots of code. (#98)
* Add skeleton lei stuff

* Add lei notebook for now and poisson process fun study

* Add stuff for ben to see

* Add current changes for Ben once again (so needy :P)

* Fixed the hang problem.

* Commit what I have

* Add currently broken code once again

* Simplify notebook

* Add working version of lewis

* Add current progress

* Add batching method and current logic of lei

* Move lei stuff to its own package

* Working on linear interpolation for the Lei problem.

* inlaw -> outlaw.

* Add lei test n_config

* Fix settings.json

* Add unit tests and update notebook with correct simulations

* Update test, add simulation tests, add point batcher, share RNG

* Update comment

* JAX implementation of scipy.interpolate.interpn (#47)

* JAX Interpolation.

* JAX implementation of scipy.interpolate.interpn

* Update todo list.

* Add current version lol

* Fix bugs and integrate good version

* Fix small bug in stage 2 and clean up code

* Modify interpn to work with multi-dimensional values

* Add current version of notebook

* WTF

* Finish final lei

* Fix test in outlaw

* Add python notebook (weird vscode lol)

* Add lei simulator batching method

* Remove unnecessary files cluttering up space

* Add current state

* Add upper bound logic to lei example

* Add ignore to frontend and update lei flow

* Clean up lewis code and include some of Ben's changes

* Add new script

* Add new changes to make memory ok

* Add full changes to everything except key

* Add checkpointing

* Add modified version

* First pass at holder-odi bound in binomial.py

* Holder-ODI, feeling more confident.

* Add analyze lei example

* Move lewis into confirm

* Fix analyze notebook with new import structure

* Add np.isnan check for holder bound and update lei analyze scripts

* Moving files, small tweaks.

* Pre-commit fixes.

* Most tests passing.

* Adagrid working nicely for validation, on to tuning.

* Working on adagrid + tuning.

* Seems to work.

* 3D berry adagrid + tuning.

* table caching for lewis.

* Fix test stage1.

* lewis adagrid.

* Lewis + Ada + Tuning.

* Symmetry

* 4D

* Fixing bugs in the cv_idx calculation and adagrid refinement criterion.

* Bootstrap tuning.

* 3D big job running.

* 3D run successful.

* Save.

* Plots for 3D Lei.

* Tweaks to cloud documentation. Also docker auto-start on reboot.

* Document s3 push.

* Little clean ups, leaving the big cleanups for later.

* Reduce the bootstrap output to avoid storing so many results.

* All tests fixed/passing.

* All tests fixed/passing.

* Remove the vertices array from grid. (#87)

* Remove the vertices array.

* vertices as property so that existing code keeps working.

* Fixes to incorporate tbt/ada_tuning stuff.

* Using exponential holder bound for tuning.

* Latest criterion.

* Running big 4D job with improved refinement rule and new EH bound and no tile vertices.

* Fast test_simulation.py (#89)

Speed up test_simulation.py

* Exponential checkpointing.

* Weird memory errors that might be due to the machine.

* Separated tuning from simulation.

* Working through the 4d lei job.

* Fixing up the criterion (inflation?) and running 4d lei.

* Refactoring the adagrid code.

* Impossibility...

* Moving to aws since the run seems to be going well.

* 4D Lei finally working beautifully. Lots of memory usage fixes.

* Working on Lewis figures.

* Rename inspector2

* Bug fixes in criterion.

* Inspector upgrades and a criterion that includes twb_mean_lam.

* Database exploration.

* Using tilt-bound in AdaRunner.

* Fix test_batch.

Co-authored-by: James Yang <jamesyang916@gmail.com>
tbenthompson and JamesYang007 authored Nov 2, 2022
1 parent 804a1bc commit 6a818ce
Showing 43 changed files with 9,547 additions and 1,031 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -46,4 +46,7 @@ venv
# Cloud and AWS stuff
.terraform
.terraform.*
terraform.*
terraform.*

# explicitly ignore a file.
*.gitignore
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -110,7 +110,7 @@
"matrixfunctions": "cpp",
"bvh": "cpp"
},
"C_Cpp.errorSquiggles": "Enabled",
"C_Cpp.errorSquiggles": "enabled",
"C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}",
"cmake.configureOnOpen": false,
"python.testing.unittestEnabled": false,
File renamed without changes
16 changes: 11 additions & 5 deletions cloud/README.md
@@ -72,7 +72,10 @@ We have some S3 Buckets. These contain various important data:

- `imprint-dump` - each subfolder here should contain the output of a model run.
- `aws-cloudtrail-logs-644171722153-2d03f9cb` - AWS access/management logs of everything we've done.
- `s3-access-logs-111222` - S3 access logs

Pushing data:
- To push a folder of data to an S3 bucket: `aws s3 sync ./source_foldername s3://imprint-dump/target_foldername`
- To push a single file to an S3 bucket: `aws s3 cp ./source_file s3://imprint-dump/target_file`
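
If you prefer to script an upload from Python, here is a minimal boto3 sketch (this is only an illustration, not part of this diff: it assumes `boto3` is installed and your AWS credentials are already configured; the source path and destination key are placeholders):

```python
import boto3

s3 = boto3.client("s3")

# Upload a single file to the imprint-dump bucket (equivalent to `aws s3 cp`).
# "target_foldername/target_file" is a placeholder destination key.
s3.upload_file("./source_file", "imprint-dump", "target_foldername/target_file")
```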

## Using VSCode Dev Containers

@@ -139,10 +142,13 @@ TODO: I think this is one of the remaining important tasks here. See the [issue

- Stop the instance using the AWS CLI or the Console
- Restart the instance using the AWS CLI or the Console
- `terraform apply` to update the terraform outputs (the public ipv4 DNS url will have changed)
- you might need to start docker... `./connect.sh` then `sudo systemctl start docker`. We could integrate this step into `./setup_remotedev.sh`.
- `./setup_remotedev.sh` to re-initialize the remote machine
- Open the docker sidebar in VSCode, start the relevant stopped container.
- Run `terraform apply`. Read the plan carefully to make sure that what you want to happen is going to happen. If you used a variable file when you created the instance, you need to pass the same variable file again here like `terraform apply -var-file="gpumachine.tfvars"`. In most cases, the only thing that will have changed is the public IPv4 DNS URL.
- You shouldn't need to start docker since it's setup to start automatically on boot. But, if you do: `./connect.sh` then `sudo systemctl start docker`.
- If connecting fails there are a few potential explanations:
1. Maybe you need to log in to AWS? `aws sso configure`
2. Maybe your ssh key is not being recognized. Try running `ssh-add --apple-use-keychain ~/.ssh/aws-key-pair.pem`
- `./setup_remotedev.sh` to re-initialize the remote docker context.
- Open the Docker sidebar in VSCode, start the relevant stopped container. It should have a name like `vsc-confirmasaurus-...`.
- Then, run the VSCode command "Dev Containers: Attach to running container".
- Once the container has launched, open the existing workspace folder inside the remote docker container. Probably `/workspaces/confirmasaurus`.

6 changes: 5 additions & 1 deletion cloud/devinstance/init_amzn_linux.sh
@@ -10,4 +10,8 @@ distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
sudo yum install --disablerepo="*" --enablerepo="libnvidia-container" nvidia-container-toolkit -y

sudo service docker start
sudo usermod -a -G docker ec2-user
sudo usermod -a -G docker ec2-user

# set docker to start automatically on boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service
2 changes: 2 additions & 0 deletions cloud/images/smalldev/Dockerfile
@@ -26,6 +26,8 @@ RUN apt-get update \
software-properties-common \
dirmngr \
neovim \
cm-super \
dvipng \
&& ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \
&& echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc \
&& echo "conda activate base" >> ~/.bashrc \
5 changes: 4 additions & 1 deletion confirm/confirm/berrylib/fast_inla.py
@@ -151,6 +151,10 @@ def rejection_inference(self, data, method="jax"):
_, exceedance, _, _ = self.inference(data, method)
return exceedance > self.critical_value

def test_inference(self, data, method="jax"):
_, exceedance, _, _ = self.inference(data, method)
return exceedance

def inference(self, data, method="jax"):
fncs = dict(
numpy=self.numpy_inference, jax=self.jax_inference, cpp=self.cpp_inference
@@ -173,7 +177,6 @@ def numpy_inference(self, data, thresh_theta=None):
if thresh_theta is None:
thresh_theta = self.thresh_theta

# TODO: warm start with DB theta ?
# Step 1) Compute the mode of p(theta, y, sigma^2) holding y and sigma^2 fixed.
# This is a simple Newton's method implementation.
theta_max, hess_inv = self.optimize_mode(data)
139 changes: 125 additions & 14 deletions confirm/confirm/lewislib/batch.py
@@ -1,16 +1,32 @@
import jax.numpy as jnp
import numpy as np

# TODO: allow batch to decide internally what batch size to use??

def pad_arg__(a, axis, n_pad: int):

def _pad_arg(a, axis, n_pad: int):
"""
Pads an array:
- along the specified axis.
- with the values at index 0
- by n_pad elements.
Padding with the values at index 0 avoids problems with using a placeholder
value like 0 in situations where the placeholder value would be invalid.
"""
pad_element = np.take(a, indices=0, axis=axis)
pad_element = np.expand_dims(pad_element, axis=axis)
new_shape = tuple(a.shape[i] if i != axis else n_pad for i in range(a.ndim))
return np.concatenate((a, np.full(new_shape, pad_element)), axis=axis)
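

# Illustrative sketch (not part of this diff): what _pad_arg does for a small
# array, assuming axis=0 and n_pad=2 -- the row at index 0 is repeated as padding:
#
#   a = np.array([[1, 2], [3, 4], [5, 6]])
#   _pad_arg(a, axis=0, n_pad=2)
#   # -> array([[1, 2],
#   #           [3, 4],
#   #           [5, 6],
#   #           [1, 2],
#   #           [1, 2]])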


def create_batched_args__(args, in_axes, start, end, n_pad=None):
def _create_batched_args(args, in_axes, start, end, n_pad=None):
"""
Subsets and pads the arguments as specified in in_axes.
"""

def arg_transform(arg, axis):
return pad_arg__(arg, axis, n_pad) if n_pad is not None else arg
return _pad_arg(arg, axis, n_pad) if n_pad is not None else arg

return [
arg_transform(
@@ -23,7 +39,17 @@ def arg_transform(arg, axis):
]


def batch(f, batch_size: int, in_axes):
def batch_yield(f, batch_size: int, in_axes):
"""
A generator that yields batches of output from the function f.
Args:
f: The function to be batched.
batch_size: The batch size.
in_axes: For each argument, the axis along which to batch. If None, the
argument is not batched.
"""

def internal(*args):
dims = np.array(
[arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None]
@@ -41,28 +67,37 @@ def internal(*args):
"along their corresopnding in_axes."
)

if len(args) != len(in_axes):
raise ValueError(
"The number of arguments must match the number of in_axes."
)

dim = dims[0]
batch_size_new = min(batch_size, dim)
n_full_batches = dim // batch_size_new
remainder = dim % batch_size_new
n_pad = batch_size_new - remainder

# NOTE: i don't think we should shrink the batch size because that'll
# incur extra JIT overhead when a user calls with lots of different
# small sizes. but we could make this a configurable behavior.
# batch_size_new = min(batch_size, dim)
n_full_batches = dim // batch_size
remainder = dim % batch_size
n_pad = batch_size - remainder
pad_last = remainder > 0
start = 0
end = batch_size_new
end = batch_size

for _ in range(n_full_batches):
batched_args = create_batched_args__(
batched_args = _create_batched_args(
args=args,
in_axes=in_axes,
start=start,
end=end,
)
yield (f(*batched_args), 0)
start += batch_size_new
end += batch_size_new
start += batch_size
end += batch_size

if pad_last:
batched_args = create_batched_args__(
batched_args = _create_batched_args(
args=args,
in_axes=in_axes,
start=start,
@@ -75,10 +110,86 @@


def batch_all(f, batch_size: int, in_axes):
f_batch = batch(f, batch_size, in_axes)
"""
A function wrapper that batches calls to f.
Args:
f: Function to be batched.
batch_size: The batch size.
in_axes: For each argument, the axis along which to batch. If None, the
argument is not batched.
Returns:
A tuple (outputs, n_pad), where outputs contains the result of each
batch call and n_pad is the number of padded entries in the final batch.
"""
f_batch = batch_yield(f, batch_size, in_axes)

def internal(*args):
outs = tuple(out for out in f_batch(*args))
return tuple(out[0] for out in outs), outs[-1][-1]

return internal


def batch(f, batch_size: int, in_axes, out_axes=None):
"""
Batch a function call and concatenate the output.
The API is intended to be similar to jax.vmap.
https://jax.readthedocs.io/en/latest/_modules/jax/_src/api.html#vmap
If the function has a single output, the output is concatenated along the
specified axis. If the function has multiple outputs, each output is
concatenated along the corresponding axis.
Args:
f: Function to be batched.
batch_size: The batch size.
in_axes: For each argument, the axis along which to batch. If None, the
argument is not batched.
out_axes: The axis along which to concatenate function outputs.
Defaults to None which will concatenate along the first axis.
Returns:
A concatenated array or a tuple of concatenated arrays.
"""
f_batch_all = batch_all(f, batch_size, in_axes)

def internal(*args):
outs, n_pad = f_batch_all(*args)

return_first = False
if isinstance(outs[0], np.ndarray) or isinstance(outs[0], jnp.DeviceArray):
return_first = True
outs = [[o] for o in outs]
internal_out_axes = (0,) if out_axes is None else out_axes
else:
internal_out_axes = (
out_axes
if out_axes is not None
else tuple(0 for _ in range(len(outs[0])))
)

def entry(i, j):
if j == len(outs) - 1 and n_pad > 0:
axis = internal_out_axes[i]
N = outs[-1][i].shape[axis]
return np.take(
outs[-1][i], np.r_[0 : N - n_pad], mode="clip", axis=axis
)
else:
return outs[j][i]

return_vals = [
np.concatenate(
[entry(i, j) for j in range(len(outs))],
axis=internal_out_axes[i],
)
for i in range(len(outs[0]))
]
if return_first:
return return_vals[0]
else:
return return_vals

return internal
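

# --- Illustrative usage sketch (not part of this diff) ---
# How `batch` is intended to be called, based on the docstring above. The
# simulator function, shapes, and import path below are assumptions made for
# illustration only.
#
#   import numpy as np
#   from confirm.lewislib.batch import batch
#
#   def simulate(theta, seeds):
#       # theta: (3,) parameters (not batched); seeds: (n,) per-simulation values
#       return seeds[:, None] * theta[None, :]
#
#   batched_simulate = batch(simulate, batch_size=1000, in_axes=(None, 0))
#   out = batched_simulate(np.ones(3), np.arange(10_500))
#   # out.shape == (10500, 3); the final partial batch of 500 is padded up to
#   # 1000 internally and the padding is stripped before concatenation.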
