Skip to content

Commit

Permalink
Merge pull request #20 from martius-lab/dev/fourier_exp
Browse files Browse the repository at this point in the history
Fourier exp for rebuttal
  • Loading branch information
AndReGeist authored Mar 28, 2024
2 parents c1d0cb3 + 88c18ac commit 2e361bf
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 91 deletions.
4 changes: 3 additions & 1 deletion hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ def get_cfg_pose_to_cube_image(device):
cfg = {
"_target_": "hitchhiking_rotations.utils.Trainer",
"lr": 0.01,
"optimizer": "SGD",
"patience": 10,
"optimizer": "Adam",
"logger": "${logger}",
"verbose": "${verbose}",
"device": device,
Expand Down Expand Up @@ -50,6 +51,7 @@ def get_cfg_pose_to_cube_image(device):
"trainers": {
"r9_l2": {**cfg, **{"preprocess_input": "${u:flatten}", "model": "${model9}"}},
"r6_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_gramschmidt_f}", "model": "${model6}"}},
"quat_aug_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_aug}", "model": "${model4}"}},
"quat_c_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_canonical}", "model": "${model4}"}},
"quat_rf_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_rand_flip}", "model": "${model4}"}},
"euler_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_euler}", "model": "${model3}"}},
Expand Down
6 changes: 4 additions & 2 deletions hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
def get_cfg_pose_to_fourier(device, nb, nf):
cfg = {
"_target_": "hitchhiking_rotations.utils.Trainer",
"lr": 0.01,
"lr": 0.001,
"patience": 10,
"optimizer": "Adam",
"logger": "${logger}",
"verbose": "${verbose}",
Expand Down Expand Up @@ -40,7 +41,7 @@ def get_cfg_pose_to_fourier(device, nb, nf):
"val_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset",
"mode": "val",
"dataset_size": 400,
"dataset_size": 200,
"device": device,
"nb": nb,
"nf": nf,
Expand All @@ -56,6 +57,7 @@ def get_cfg_pose_to_fourier(device, nb, nf):
"trainers": {
"r9_l2": {**cfg, **{"preprocess_input": "${u:flatten}", "model": "${model9}"}},
"r6_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_gramschmidt_f}", "model": "${model6}"}},
"quat_aug_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_aug}", "model": "${model4}"}},
"quat_c_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_canonical}", "model": "${model4}"}},
"quat_rf_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_rand_flip}", "model": "${model4}"}},
"euler_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_euler}", "model": "${model3}"}},
Expand Down
Binary file added hitchhiking_rotations/datasets/.cube_data.swp
Binary file not shown.
83 changes: 61 additions & 22 deletions hitchhiking_rotations/datasets/fourier_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import roma
import jax
import jax.numpy as jnp
import equinox as eqx

jax.config.update("jax_default_device", jax.devices("cpu")[0])

from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle, load_pickle
Expand Down Expand Up @@ -42,25 +47,31 @@ def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx]


class random_fourier_function:
def __init__(self, n_basis, seed, A0=0.0, L=1.0):
np.random.seed(seed)
self.L = L
self.n_basis = n_basis
self.A0 = A0
self.A = np.random.normal(size=n_basis)
self.B = np.random.normal(size=n_basis)
self.matrix = np.random.normal(size=(1, 9))
def random_fourier_function(x, nb, seed):
key = jax.random.PRNGKey(seed)
key1, key2 = jax.random.split(key, 2)
A = jax.random.normal(key=key1, shape=(nb,))
B = jax.random.normal(key=key2, shape=(nb,))

model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42 + seed))

fFs = 0.0
input = model(x)
for k in range(len(A)):
fFs += A[k] * jnp.cos((k + 1) * jnp.pi * input) + B[k] * jnp.sin((k + 1) * jnp.pi * input)
return fFs


def input_to_fourier(x, seed):
model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42 + seed))
return model(x)

def __call__(self, x):
fFs = self.A0 / 2
for k in range(len(self.A)):
fFs = (
fFs
+ self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L)
+ self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L)
)
return fFs

def batch_normalize(arr):
mean = np.mean(arr, axis=0, keepdims=True)
std = np.std(arr, axis=0, keepdims=True)
std[std == 0] = 1
return (arr - mean) / std


def create_data(N_points, nb, seed):
Expand All @@ -69,20 +80,21 @@ def create_data(N_points, nb, seed):
Args:
N_points: Number of random rotations to generate
nb: Number of fourier basis that form the target function
seed: Used to randomly initialize fourier function coefficients
seed: Used to randomly initialize fourier function
Returns:
rots: Random rotations
features: Target function evaluated at rots
"""
np.random.seed(seed)
rots = Rotation.random(N_points)
inputs = rots.as_matrix().reshape(N_points, -1)
four_func = random_fourier_function(nb, seed)
features = np.apply_along_axis(four_func, 1, inputs)
features = np.array(jax.vmap(random_fourier_function, in_axes=[0, None, None])(inputs, nb, seed).reshape(-1, 1))
features = batch_normalize(features)
return rots.as_quat().astype(np.float32), features.astype(np.float32)


def plot_fourier_data(rotations, features):
"""Plot distribution of rotations and features."""
import pandas as pd
import seaborn as sns

Expand All @@ -99,5 +111,32 @@ def plot_fourier_data(rotations, features):
plt.show()


def plot_fourier_func(nb, seed):
"""Plot the target function."""
rots = Rotation.random(400)
inputs = rots.as_matrix().reshape(400, -1)
four_in = np.array(jax.vmap(input_to_fourier, [0, None])(inputs, seed))
features = np.array(jax.vmap(random_fourier_function, [0, None, None])(inputs, nb, seed))
features2 = batch_normalize(features)
sorted_indices = np.argsort(four_in, axis=0)

plt.figure()
plt.plot(four_in[sorted_indices].flatten(), features[sorted_indices].flatten(), linestyle="-", marker=None)
plt.plot(
four_in[sorted_indices].flatten(), features2[sorted_indices].flatten(), linestyle="-", color="red", marker=None
)
plt.title(f"nb: {nb}, seed: {seed}")
plt.show()


if __name__ == "__main__":
create_data(N_points=100, nb=2, seed=5)
# Analyze created data
for b in range(1, 6):
for s in range(0, 1):
# rots, features = create_data(N_points=100, nb=b, seed=s)
# data_stats(rots, features)
# plot_fourier_data(rots, features)
print("MLP PyTree used to create Fourier function inputs:")
model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42))
eqx.tree_pprint(model)
plot_fourier_func(b, s)
2 changes: 1 addition & 1 deletion hitchhiking_rotations/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
from .models import MLP, CNN, MLPNetPCD
from .models import MLP, MLP2, CNN, MLPNetPCD
17 changes: 17 additions & 0 deletions hitchhiking_rotations/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ def forward(self, x):
return self.model(x)


class MLP2(nn.Module):
def __init__(self, input_dim, output_dim):
super(MLP2, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 200),
nn.ReLU(),
nn.Linear(200, 200),
nn.ReLU(),
nn.Linear(200, 200),
nn.ReLU(),
nn.Linear(200, output_dim),
)

def forward(self, x):
return self.model(x)


class CNN(nn.Module):
def __init__(self, input_dim, width, height):
super(CNN, self).__init__()
Expand Down
58 changes: 37 additions & 21 deletions hitchhiking_rotations/utils/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
from hitchhiking_rotations.utils.euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
import roma
import torch
from math import pi


def euler_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
def euler_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
return euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY")


def quaternion_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
def quaternion_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
# without normalization
# normalize first
x = inp.reshape(-1, 4)
return roma.unitquat_to_rotmat(x / x.norm(dim=1, keepdim=True))


def gramschmidt_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
def gramschmidt_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
return roma.special_gramschmidt(inp.reshape(-1, 3, 2))


def symmetric_orthogonalization(x):
def symmetric_orthogonalization(x, **kwargs):
"""Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
x: should have size [batch_size, 9]
Expand All @@ -40,65 +41,80 @@ def symmetric_orthogonalization(x):
return r


def procrustes_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
def procrustes_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
return symmetric_orthogonalization(inp)
return roma.special_procrustes(inp.reshape(-1, 3, 3))


def rotvec_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
def rotvec_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
return roma.rotvec_to_rotmat(inp.reshape(-1, 3))


# rotmat to x / maybe here reshape is missing


def rotmat_to_euler(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_euler(base: torch.Tensor, **kwargs) -> torch.Tensor:
return matrix_to_euler_angles(base, convention="XZY")


def rotmat_to_quaternion(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_quaternion(base: torch.Tensor, **kwargs) -> torch.Tensor:
return roma.rotmat_to_unitquat(base)


def rotmat_to_quaternion_rand_flip(base: torch.Tensor) -> torch.Tensor:
# we could duplicate the data and flip the quaternions on both sides
# def quat_aug_dataset(quats: np.ndarray, ixs):
# # quats: (N, M, .., 4)
# # return augmented inputs and quats
# return (np.concatenate((quats, -quats), axis=0), *np.concatenate((ixs, ixs), axis=0))

def rotmat_to_quaternion_rand_flip(base: torch.Tensor, **kwargs) -> torch.Tensor:
rep = roma.rotmat_to_unitquat(base)
rand_flipping = torch.rand(base.shape[0]) > 0.5
rep[rand_flipping] *= -1
return rep


def rotmat_to_quaternion_canonical(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_quaternion_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor:
rep = roma.rotmat_to_unitquat(base)
rep[rep[:, 3] < 0] *= -1
return rep


def rotmat_to_gramschmidt(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_quaternion_aug(base: torch.Tensor, mode: str) -> torch.Tensor:
"""Performs memory-efficient quaternion augmentation by randomly
selecting half of the quaternions in the batch with scalar part
smaller than 0.1 and then multiplies them by -1.
"""
rep = rotmat_to_quaternion_canonical(base)

if mode == "train":
rep[(torch.rand(rep.size(0), device=rep.device) < 0.5) * (rep[:, 3] < 0.1)] *= -1

return rep


def rotmat_to_gramschmidt(base: torch.Tensor, **kwargs) -> torch.Tensor:
return base[:, :, :2]


def rotmat_to_gramschmidt_f(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_gramschmidt_f(base: torch.Tensor, **kwargs) -> torch.Tensor:
return base[:, :, :2].reshape(-1, 6)


def rotmat_to_procrustes(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_procrustes(base: torch.Tensor, **kwargs) -> torch.Tensor:
return base


def rotmat_to_rotvec(base: torch.Tensor) -> torch.Tensor:
def rotmat_to_rotvec(base: torch.Tensor, **kwargs) -> torch.Tensor:
return roma.rotmat_to_rotvec(base)


def rotmat_to_rotvec_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor:
"""WARNING: THIS FUNCTION HAS NOT BEEN TESTED"""
rep = roma.rotmat_to_rotvec(base)
rep[rep[:, 2] < 0] = (1.0 - 2.0 * pi / rep[rep[:, 2] < 0].norm(dim=1, keepdim=True)) * rep[rep[:, 2] < 0]
return rep


def test_all():
from scipy.spatial.transform import Rotation
from torch import from_numpy as tr
import numpy as np
from torch import from_numpy as tr

rs = Rotation.random(1000)
euler = rs.as_euler("XZY", degrees=False)
Expand Down
6 changes: 3 additions & 3 deletions hitchhiking_rotations/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
def passthrough(*x):
def passthrough(*x, **kwargs):
if len(x) == 1:
return x[0]
return x


def flatten(x):
def flatten(x, **kwargs):
return x.reshape(x.shape[0], -1)


def n_3x3(x):
def n_3x3(x, **kwargs):
return x.reshape(-1, 3, 3)
19 changes: 10 additions & 9 deletions hitchhiking_rotations/utils/notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@


class RotRep(Enum):
GSO = "$\mathbb{R}^6$+GSO"
SVD = "$\mathbb{R}^9$+SVD"
QUAT_C = "Quat$^+$"
QUAT = "Quat"
QUAT_RF = "Quat+RF"
EULER = "Euler"
EXP = "Exp"
ROTMAT = "$\mathbb{R}^9$"
RSIX = "$\mathbb{R}^6$"
GSO = r"$\mathbb{R}^6$+GSO"
SVD = r"$\mathbb{R}^9$+SVD"
QUAT_C = r"Quat$^+$"
QUAT = r"Quat"
QUAT_RF = r"Quat$^{\mathrm{RF}}$"
QUAT_AUG = r"Quat$^{\mathrm{a}{+}}$"
EULER = r"Euler"
EXP = r"Exp"
ROTMAT = r"$\mathbb{R}^9$"
RSIX = r"$\mathbb{R}^6$"

def __str__(self):
return "%s" % self.value
Loading

0 comments on commit 2e361bf

Please sign in to comment.