-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:martius-lab/hitchhiking-rotations i…
…nto main
- Loading branch information
Showing
10 changed files
with
478 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
gigachad_colors = [ | ||
(0.368, 0.507, 0.71), | ||
(0.881, 0.611, 0.142), | ||
(0.923, 0.386, 0.209), | ||
(0.56, 0.692, 0.195), | ||
(0.528, 0.471, 0.701), | ||
(0.772, 0.432, 0.102), | ||
(0.572, 0.586, 0.0), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
,method,score,metric | ||
0,$\mathbb{R}^9$+SVD,91.04188604844259,AUC ADD-S | ||
1,$\mathbb{R}^9$+SVD,95.07767991069852,<2cm | ||
2,$\mathbb{R}^9$+SVD,91.3637847207741,AUC ADD-S | ||
3,$\mathbb{R}^9$+SVD,94.59201274202522,<2cm | ||
4,$\mathbb{R}^9$+SVD,91.08496590464102,AUC ADD-S | ||
5,$\mathbb{R}^9$+SVD,94.61647270550057,<2cm | ||
6,$\mathbb{R}^6$+GSO,91.3470127729485,AUC ADD-S | ||
7,$\mathbb{R}^6$+GSO,94.19720889771384,<2cm | ||
8,$\mathbb{R}^6$+GSO,91.42867218065841,AUC ADD-S | ||
9,$\mathbb{R}^6$+GSO,94.3550012793383,<2cm | ||
10,$\mathbb{R}^6$+GSO,90.77961308438314,AUC ADD-S | ||
11,$\mathbb{R}^6$+GSO,95.43328935233129,<2cm | ||
12,Quat$^+$,91.3210065650409,AUC ADD-S | ||
13,Quat$^+$,95.46569600702125,<2cm | ||
14,Quat$^+$,90.9166210514394,AUC ADD-S | ||
15,Quat$^+$,94.20788355159773,<2cm | ||
16,Quat$^+$,91.32054850636075,AUC ADD-S | ||
17,Quat$^+$,94.56954022581728,<2cm | ||
18,Euler,90.77298990957122,AUC ADD-S | ||
19,Euler,92.94620768691362,<2cm | ||
20,Euler,90.63536905144662,AUC ADD-S | ||
21,Euler,93.23661241181223,<2cm | ||
22,Euler,90.82663296670404,AUC ADD-S | ||
23,Euler,93.52104289463712,<2cm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from pytictac import Timer, CpuTimer | ||
import torch | ||
|
||
|
||
def test_svd(): | ||
BS = 128 | ||
repeats = 100 | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
if torch.cuda.is_available(): | ||
tim = Timer | ||
else: | ||
tim = CpuTimer | ||
|
||
for i in range(repeats): | ||
m = torch.rand((BS, 3, 3), device=device) | ||
u, s, v = torch.svd(m) | ||
|
||
with tim("SVD"): | ||
for i in range(repeats): | ||
m = torch.rand((BS, 3, 3), device=device) | ||
u, s, v = torch.svd(m) | ||
|
||
m = torch.rand((BS, 3, 3), device=device) | ||
with tim("SVD"): | ||
for i in range(repeats): | ||
u, s, v = torch.svd(m) | ||
|
||
m = torch.rand((BS, 3, 3), device=device) | ||
with tim("SVD single"): | ||
u, s, v = torch.svd(m) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_svd() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import os | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
import numpy as np | ||
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR | ||
from matplotlib.colors import LinearSegmentedColormap | ||
import seaborn as sns | ||
|
||
|
||
def generate_colormap(base_color, lvl=0.2, alpha=0.5): | ||
# Convert base color to RGB values between 0 and 1 | ||
base_color = np.array(base_color) / 255.0 | ||
|
||
# Define the color at the extremes (0 and 1) | ||
lighter_color = base_color + lvl | ||
darker_color = base_color - lvl | ||
|
||
# Ensure colors are within valid range | ||
lighter_color = np.clip(lighter_color, 0, 1) | ||
darker_color = np.clip(darker_color, 0, 1) | ||
|
||
# Generate colormap | ||
colors = [lighter_color, base_color, darker_color] | ||
positions = [0, alpha, 1] | ||
return LinearSegmentedColormap.from_list("custom_cmap", list(zip(positions, colors))) | ||
|
||
|
||
o = generate_colormap([224, 157, 52, 255]) | ||
b = generate_colormap([57, 84, 122, 255]) | ||
|
||
|
||
# Load data | ||
df = pd.read_csv(os.path.join(HITCHHIKING_ROOT_DIR, "results", "dense_fusion", "dense_fusion_experiment.csv")) | ||
|
||
plt.style.use(os.path.join(HITCHHIKING_ROOT_DIR, "assets", "prettyplots.mplstyle")) | ||
sns.set_style("whitegrid") | ||
|
||
|
||
# Define symbols for each method | ||
method_symbols = { | ||
"$\mathbb{R}^9$+SVD": "D", # square | ||
"$\mathbb{R}^6$+GSO": "h", # diamond | ||
"Quat$^+$": "*", # circle | ||
"Euler": "X", # cross | ||
} | ||
|
||
# Define colors for each metric | ||
metric_colors = { | ||
"AUC ADD-S": o, | ||
"<2cm": b, | ||
} | ||
|
||
# Set up colormap for gradient based on scores | ||
# Get unique methods and assign x-values to them | ||
unique_methods = df["method"].unique() | ||
method_indices = {method: idx for idx, method in enumerate(unique_methods)} | ||
|
||
i = 0 | ||
# Plotting | ||
fig = plt.figure(figsize=(5, 3)) | ||
for method, symbol in method_symbols.items(): | ||
for metric, color in metric_colors.items(): | ||
sub_df = df[(df["method"] == method) & (df["metric"] == metric)] | ||
scores = sub_df["score"] | ||
|
||
min_v = df[(df["metric"] == metric)]["score"].min() | ||
max_v = df[(df["metric"] == metric)]["score"].max() | ||
# Normalize scores for gradient colormap | ||
normalized_scores = (scores - min_v.min()) / (max_v.max() - min_v.min()) | ||
x_values = [method_indices[method]] * len(sub_df.index) | ||
x_values += np.random.uniform(-0.15, 0.15, len(x_values)) # Add jitter to x-values | ||
plt.scatter(x_values, scores, c=color(normalized_scores), marker=symbol, edgecolor="black", linewidth=0.5, s=70) | ||
|
||
# Markers for legend - put them on negative y-axis | ||
if i == 0: | ||
plt.scatter( | ||
x_values, | ||
scores * -1, | ||
c=o([0.5] * len(scores)), | ||
marker="s", | ||
label="AUC ADD-S", | ||
edgecolor="black", | ||
linewidth=0.5, | ||
s=70, | ||
) | ||
plt.scatter( | ||
x_values, | ||
scores * -1, | ||
c=b([0.5] * len(scores)), | ||
marker="s", | ||
label="<2cm", | ||
edgecolor="black", | ||
linewidth=0.5, | ||
s=70, | ||
) | ||
i = 1 | ||
|
||
# Limity y-axis to not see the markers on negative side | ||
fig.axes[0].set_ylim([90.2, 95.8]) | ||
|
||
plt.legend( | ||
title="", | ||
bbox_to_anchor=(0.5, 1.15), | ||
loc="upper center", | ||
ncol=len(method_symbols), | ||
frameon=False, | ||
borderaxespad=0.0, | ||
handletextpad=0.5, | ||
markerscale=1.0, | ||
) | ||
plt.xticks(range(len(unique_methods)), unique_methods) # Set ticks to method names | ||
plt.ylabel("Score") | ||
plt.grid(True, linestyle="--", color="gray", alpha=0.5) | ||
|
||
plt.tight_layout() | ||
plt.savefig(os.path.join(HITCHHIKING_ROOT_DIR, "results", "dense_fusion", "figure_13.pdf")) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR | ||
from hitchhiking_rotations.utils import save_pickle | ||
from hitchhiking_rotations.utils import RotRep, gigachad_colors | ||
from hitchhiking_rotations.cfgs import get_cfg_cube_image_to_pose | ||
import numpy as np | ||
import argparse | ||
import os | ||
import hydra | ||
from omegaconf import OmegaConf | ||
import torch | ||
from torch.utils.data import DataLoader | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--seed", | ||
type=int, | ||
default=0, | ||
help="Random seed used during training, " + "for pose_to_fourier the seed is used to select the target function.", | ||
) | ||
args = parser.parse_args() | ||
|
||
s = args.seed | ||
torch.manual_seed(s) | ||
np.random.seed(s) | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
torch.cuda.empty_cache() | ||
torch.zeros((1,), device=device) # Initialize CUDA | ||
|
||
cfg_exp = get_cfg_cube_image_to_pose(device) | ||
OmegaConf.register_new_resolver("u", lambda x: hydra.utils.get_method("hitchhiking_rotations.utils." + x)) | ||
cfg_exp = OmegaConf.create(cfg_exp) | ||
|
||
trainers = hydra.utils.instantiate(cfg_exp.trainers) | ||
test_data = hydra.utils.instantiate(cfg_exp.test_data) | ||
# Create dataloaders | ||
epoch = 0 | ||
|
||
timing_result = {} | ||
batch_sizes = [1, 32, 256, 1024] | ||
for batch_size in batch_sizes: | ||
test_dataloader = DataLoader(test_data, num_workers=0, batch_size=batch_size, shuffle=True) | ||
# Perform testing | ||
for j, batch in enumerate(test_dataloader): | ||
x, target = batch | ||
for trainer_name, pretty_name in zip( | ||
["r9_svd_geodesic_distance", "r9_geodesic_distance", "r6_gso_geodesic_distance"], | ||
[str(s) for s in [RotRep.SVD, RotRep.ROTMAT, RotRep.GSO]], | ||
): | ||
trainer = trainers[trainer_name] | ||
|
||
for i in range(100): | ||
trainer.test_batch(x.clone(), target.clone(), epoch, mode="test") | ||
|
||
timing_result[f"{pretty_name} \n BS-{batch_size}"] = [] | ||
for i in range(100): | ||
rand = torch.rand_like(x) * 0.00001 | ||
res, _ = trainer.test_batch_time(x + rand, target, epoch, mode="test") | ||
timing_result[f"{pretty_name} \n BS-{batch_size}"].append(res) | ||
|
||
for k in timing_result.keys(): | ||
times = np.array(timing_result[k]) | ||
print(times.shape) | ||
print(k, " mean t0-t5 ", times.mean(axis=0)) | ||
print(k, " std t0-t5 ", times.std(axis=0)) | ||
|
||
break | ||
|
||
|
||
# Chat GPT visualization | ||
|
||
# Extract timing data for each method | ||
method_names = list(timing_result.keys()) | ||
sub_timings = np.array([timing_result[k] for k in method_names]) | ||
|
||
# Calculate means for each subtiming | ||
means = sub_timings.mean(axis=1) | ||
|
||
# Create stacked bar plot | ||
fig, ax = plt.subplots(figsize=(12, 6)) | ||
|
||
bar_width = 0.8 | ||
|
||
index = [] | ||
c = 0 | ||
for b in range(len(batch_sizes)): | ||
for i in range(len(method_names) // len(batch_sizes)): | ||
index.append(c) | ||
c += 1 | ||
c += 0.5 | ||
index = np.array(index) | ||
|
||
plots = [] | ||
bottom = np.zeros(len(method_names)) | ||
|
||
subtiming_labels = [ | ||
"preprocess_input", | ||
"model_forward", | ||
"postprocess_pred_loss", | ||
"preprocess_target", | ||
"loss", | ||
"postprocess_pred_logging", | ||
] | ||
|
||
for i, label in enumerate(subtiming_labels): | ||
plot = ax.bar(index, means[:, i], bar_width, bottom=bottom, label=label, color=gigachad_colors[i]) | ||
bottom += means[:, i] | ||
plots.append(plot) | ||
|
||
ax.set_xlabel("Methods") | ||
ax.set_ylabel("Time in ms") | ||
ax.set_title("Timing Results") | ||
ax.set_xticks(index) | ||
ax.set_xticklabels(method_names) | ||
ax.legend() | ||
|
||
experiment_folder = os.path.join(HITCHHIKING_ROOT_DIR, "results", "image_to_pose_timing") | ||
os.makedirs(experiment_folder, exist_ok=True) | ||
out_p = os.path.join(experiment_folder, "rebutal_timing_network.pdf") | ||
plt.savefig(out_p) | ||
|
||
save_pickle(timing_result, os.path.join(experiment_folder, f"seed_{s}_timing_network.npy")) | ||
plt.show() |
Oops, something went wrong.