Skip to content

Commit

Permalink
Merge branch 'main' of github.com:martius-lab/hitchhiking-rotations i…
Browse files Browse the repository at this point in the history
…nto main
  • Loading branch information
jotix16 committed Mar 26, 2024
2 parents 24c57de + 58987bc commit c1d0cb3
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ repos:
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=120, "--ignore=W291,E731"]
args: [--max-line-length=120, "--ignore=W291,E731,,F401,F403"]
1 change: 1 addition & 0 deletions hitchhiking_rotations/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
from .colors import gigachad_colors
from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
from .conversions import *
from .metrics import *
Expand Down
9 changes: 9 additions & 0 deletions hitchhiking_rotations/utils/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
gigachad_colors = [
(0.368, 0.507, 0.71),
(0.881, 0.611, 0.142),
(0.923, 0.386, 0.209),
(0.56, 0.692, 0.195),
(0.528, 0.471, 0.701),
(0.772, 0.432, 0.102),
(0.572, 0.586, 0.0),
]
3 changes: 1 addition & 2 deletions hitchhiking_rotations/utils/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def rotmat_to_rotvec(base: torch.Tensor) -> torch.Tensor:

def test_all():
from scipy.spatial.transform import Rotation
from torch import from_numpy as tr
import numpy as np

rs = Rotation.random(1000)
Expand All @@ -106,8 +107,6 @@ def test_all():
quat_hm = np.where(quat[:, 3:4] < 0, -quat, quat)
rotvec = rs.as_rotvec()

tr = lambda x: torch.from_numpy(x)

# euler_to_rotmat
print(np.allclose(euler_to_rotmat(tr(euler)).numpy(), rot))
print(np.allclose(quaternion_to_rotmat(tr(quat)).numpy(), rot))
Expand Down
43 changes: 43 additions & 0 deletions hitchhiking_rotations/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,46 @@ def validation_epoch_finish(self, epoch):

def training_finish(self):
self.model.load_state_dict(self.early_stopper.best_state_dict)

@torch.no_grad()
def test_batch_time(self, x, target, epoch, mode):
self.model.eval()

t0 = torch.cuda.Event(enable_timing=True)
t1 = torch.cuda.Event(enable_timing=True)
t2 = torch.cuda.Event(enable_timing=True)
t3 = torch.cuda.Event(enable_timing=True)
t4 = torch.cuda.Event(enable_timing=True)
t5 = torch.cuda.Event(enable_timing=True)
t6 = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
t0.record()
x = self.preprocess_input(x) # Step 0
torch.cuda.synchronize()
t1.record()
pred = self.model(x) # Step 1
torch.cuda.synchronize()
t2.record()
pred_loss = self.postprocess_pred_loss(pred) # Step 2
torch.cuda.synchronize()
t3.record()
pp_target = self.preprocess_target(target) # Step 3
torch.cuda.synchronize()
t4.record()
loss = self.loss(pred_loss, pp_target) # Step 4
torch.cuda.synchronize()
t5.record()
_ = self.postprocess_pred_logging(pred) # Step 5
torch.cuda.synchronize()
t6.record()
torch.cuda.synchronize()

return [
t0.elapsed_time(t1),
t1.elapsed_time(t2),
t2.elapsed_time(t3),
t3.elapsed_time(t4),
t4.elapsed_time(t5),
t5.elapsed_time(t6),
], loss
25 changes: 25 additions & 0 deletions results/dense_fusion/dense_fusion_experiment.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
,method,score,metric
0,$\mathbb{R}^9$+SVD,91.04188604844259,AUC ADD-S
1,$\mathbb{R}^9$+SVD,95.07767991069852,<2cm
2,$\mathbb{R}^9$+SVD,91.3637847207741,AUC ADD-S
3,$\mathbb{R}^9$+SVD,94.59201274202522,<2cm
4,$\mathbb{R}^9$+SVD,91.08496590464102,AUC ADD-S
5,$\mathbb{R}^9$+SVD,94.61647270550057,<2cm
6,$\mathbb{R}^6$+GSO,91.3470127729485,AUC ADD-S
7,$\mathbb{R}^6$+GSO,94.19720889771384,<2cm
8,$\mathbb{R}^6$+GSO,91.42867218065841,AUC ADD-S
9,$\mathbb{R}^6$+GSO,94.3550012793383,<2cm
10,$\mathbb{R}^6$+GSO,90.77961308438314,AUC ADD-S
11,$\mathbb{R}^6$+GSO,95.43328935233129,<2cm
12,Quat$^+$,91.3210065650409,AUC ADD-S
13,Quat$^+$,95.46569600702125,<2cm
14,Quat$^+$,90.9166210514394,AUC ADD-S
15,Quat$^+$,94.20788355159773,<2cm
16,Quat$^+$,91.32054850636075,AUC ADD-S
17,Quat$^+$,94.56954022581728,<2cm
18,Euler,90.77298990957122,AUC ADD-S
19,Euler,92.94620768691362,<2cm
20,Euler,90.63536905144662,AUC ADD-S
21,Euler,93.23661241181223,<2cm
22,Euler,90.82663296670404,AUC ADD-S
23,Euler,93.52104289463712,<2cm
35 changes: 35 additions & 0 deletions tests/test_svd_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pytictac import Timer, CpuTimer
import torch


def test_svd():
BS = 128
repeats = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
tim = Timer
else:
tim = CpuTimer

for i in range(repeats):
m = torch.rand((BS, 3, 3), device=device)
u, s, v = torch.svd(m)

with tim("SVD"):
for i in range(repeats):
m = torch.rand((BS, 3, 3), device=device)
u, s, v = torch.svd(m)

m = torch.rand((BS, 3, 3), device=device)
with tim("SVD"):
for i in range(repeats):
u, s, v = torch.svd(m)

m = torch.rand((BS, 3, 3), device=device)
with tim("SVD single"):
u, s, v = torch.svd(m)


if __name__ == "__main__":
test_svd()
117 changes: 117 additions & 0 deletions visu/figure_13.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns


def generate_colormap(base_color, lvl=0.2, alpha=0.5):
# Convert base color to RGB values between 0 and 1
base_color = np.array(base_color) / 255.0

# Define the color at the extremes (0 and 1)
lighter_color = base_color + lvl
darker_color = base_color - lvl

# Ensure colors are within valid range
lighter_color = np.clip(lighter_color, 0, 1)
darker_color = np.clip(darker_color, 0, 1)

# Generate colormap
colors = [lighter_color, base_color, darker_color]
positions = [0, alpha, 1]
return LinearSegmentedColormap.from_list("custom_cmap", list(zip(positions, colors)))


o = generate_colormap([224, 157, 52, 255])
b = generate_colormap([57, 84, 122, 255])


# Load data
df = pd.read_csv(os.path.join(HITCHHIKING_ROOT_DIR, "results", "dense_fusion", "dense_fusion_experiment.csv"))

plt.style.use(os.path.join(HITCHHIKING_ROOT_DIR, "assets", "prettyplots.mplstyle"))
sns.set_style("whitegrid")


# Define symbols for each method
method_symbols = {
"$\mathbb{R}^9$+SVD": "D", # square
"$\mathbb{R}^6$+GSO": "h", # diamond
"Quat$^+$": "*", # circle
"Euler": "X", # cross
}

# Define colors for each metric
metric_colors = {
"AUC ADD-S": o,
"<2cm": b,
}

# Set up colormap for gradient based on scores
# Get unique methods and assign x-values to them
unique_methods = df["method"].unique()
method_indices = {method: idx for idx, method in enumerate(unique_methods)}

i = 0
# Plotting
fig = plt.figure(figsize=(5, 3))
for method, symbol in method_symbols.items():
for metric, color in metric_colors.items():
sub_df = df[(df["method"] == method) & (df["metric"] == metric)]
scores = sub_df["score"]

min_v = df[(df["metric"] == metric)]["score"].min()
max_v = df[(df["metric"] == metric)]["score"].max()
# Normalize scores for gradient colormap
normalized_scores = (scores - min_v.min()) / (max_v.max() - min_v.min())
x_values = [method_indices[method]] * len(sub_df.index)
x_values += np.random.uniform(-0.15, 0.15, len(x_values)) # Add jitter to x-values
plt.scatter(x_values, scores, c=color(normalized_scores), marker=symbol, edgecolor="black", linewidth=0.5, s=70)

# Markers for legend - put them on negative y-axis
if i == 0:
plt.scatter(
x_values,
scores * -1,
c=o([0.5] * len(scores)),
marker="s",
label="AUC ADD-S",
edgecolor="black",
linewidth=0.5,
s=70,
)
plt.scatter(
x_values,
scores * -1,
c=b([0.5] * len(scores)),
marker="s",
label="<2cm",
edgecolor="black",
linewidth=0.5,
s=70,
)
i = 1

# Limity y-axis to not see the markers on negative side
fig.axes[0].set_ylim([90.2, 95.8])

plt.legend(
title="",
bbox_to_anchor=(0.5, 1.15),
loc="upper center",
ncol=len(method_symbols),
frameon=False,
borderaxespad=0.0,
handletextpad=0.5,
markerscale=1.0,
)
plt.xticks(range(len(unique_methods)), unique_methods) # Set ticks to method names
plt.ylabel("Score")
plt.grid(True, linestyle="--", color="gray", alpha=0.5)

plt.tight_layout()
plt.savefig(os.path.join(HITCHHIKING_ROOT_DIR, "results", "dense_fusion", "figure_13.pdf"))
plt.show()
125 changes: 125 additions & 0 deletions visu/time_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle
from hitchhiking_rotations.utils import RotRep, gigachad_colors
from hitchhiking_rotations.cfgs import get_cfg_cube_image_to_pose
import numpy as np
import argparse
import os
import hydra
from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


parser = argparse.ArgumentParser()

parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed used during training, " + "for pose_to_fourier the seed is used to select the target function.",
)
args = parser.parse_args()

s = args.seed
torch.manual_seed(s)
np.random.seed(s)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
torch.zeros((1,), device=device) # Initialize CUDA

cfg_exp = get_cfg_cube_image_to_pose(device)
OmegaConf.register_new_resolver("u", lambda x: hydra.utils.get_method("hitchhiking_rotations.utils." + x))
cfg_exp = OmegaConf.create(cfg_exp)

trainers = hydra.utils.instantiate(cfg_exp.trainers)
test_data = hydra.utils.instantiate(cfg_exp.test_data)
# Create dataloaders
epoch = 0

timing_result = {}
batch_sizes = [1, 32, 256, 1024]
for batch_size in batch_sizes:
test_dataloader = DataLoader(test_data, num_workers=0, batch_size=batch_size, shuffle=True)
# Perform testing
for j, batch in enumerate(test_dataloader):
x, target = batch
for trainer_name, pretty_name in zip(
["r9_svd_geodesic_distance", "r9_geodesic_distance", "r6_gso_geodesic_distance"],
[str(s) for s in [RotRep.SVD, RotRep.ROTMAT, RotRep.GSO]],
):
trainer = trainers[trainer_name]

for i in range(100):
trainer.test_batch(x.clone(), target.clone(), epoch, mode="test")

timing_result[f"{pretty_name} \n BS-{batch_size}"] = []
for i in range(100):
rand = torch.rand_like(x) * 0.00001
res, _ = trainer.test_batch_time(x + rand, target, epoch, mode="test")
timing_result[f"{pretty_name} \n BS-{batch_size}"].append(res)

for k in timing_result.keys():
times = np.array(timing_result[k])
print(times.shape)
print(k, " mean t0-t5 ", times.mean(axis=0))
print(k, " std t0-t5 ", times.std(axis=0))

break


# Chat GPT visualization

# Extract timing data for each method
method_names = list(timing_result.keys())
sub_timings = np.array([timing_result[k] for k in method_names])

# Calculate means for each subtiming
means = sub_timings.mean(axis=1)

# Create stacked bar plot
fig, ax = plt.subplots(figsize=(12, 6))

bar_width = 0.8

index = []
c = 0
for b in range(len(batch_sizes)):
for i in range(len(method_names) // len(batch_sizes)):
index.append(c)
c += 1
c += 0.5
index = np.array(index)

plots = []
bottom = np.zeros(len(method_names))

subtiming_labels = [
"preprocess_input",
"model_forward",
"postprocess_pred_loss",
"preprocess_target",
"loss",
"postprocess_pred_logging",
]

for i, label in enumerate(subtiming_labels):
plot = ax.bar(index, means[:, i], bar_width, bottom=bottom, label=label, color=gigachad_colors[i])
bottom += means[:, i]
plots.append(plot)

ax.set_xlabel("Methods")
ax.set_ylabel("Time in ms")
ax.set_title("Timing Results")
ax.set_xticks(index)
ax.set_xticklabels(method_names)
ax.legend()

experiment_folder = os.path.join(HITCHHIKING_ROOT_DIR, "results", "image_to_pose_timing")
os.makedirs(experiment_folder, exist_ok=True)
out_p = os.path.join(experiment_folder, "rebutal_timing_network.pdf")
plt.savefig(out_p)

save_pickle(timing_result, os.path.join(experiment_folder, f"seed_{s}_timing_network.npy"))
plt.show()
Loading

0 comments on commit c1d0cb3

Please sign in to comment.