diff --git a/hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py b/hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py index 083785b..84a7314 100644 --- a/hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py +++ b/hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py @@ -7,7 +7,8 @@ def get_cfg_pose_to_cube_image(device): cfg = { "_target_": "hitchhiking_rotations.utils.Trainer", "lr": 0.01, - "optimizer": "SGD", + "patience": 10, + "optimizer": "Adam", "logger": "${logger}", "verbose": "${verbose}", "device": device, @@ -50,6 +51,7 @@ def get_cfg_pose_to_cube_image(device): "trainers": { "r9_l2": {**cfg, **{"preprocess_input": "${u:flatten}", "model": "${model9}"}}, "r6_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_gramschmidt_f}", "model": "${model6}"}}, + "quat_aug_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_aug}", "model": "${model4}"}}, "quat_c_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_canonical}", "model": "${model4}"}}, "quat_rf_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_rand_flip}", "model": "${model4}"}}, "euler_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_euler}", "model": "${model3}"}}, diff --git a/hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py b/hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py index da5f5af..19b1c2b 100644 --- a/hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py +++ b/hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py @@ -6,7 +6,8 @@ def get_cfg_pose_to_fourier(device, nb, nf): cfg = { "_target_": "hitchhiking_rotations.utils.Trainer", - "lr": 0.01, + "lr": 0.001, + "patience": 10, "optimizer": "Adam", "logger": "${logger}", "verbose": "${verbose}", @@ -40,7 +41,7 @@ def get_cfg_pose_to_fourier(device, nb, nf): "val_data": { "_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset", "mode": "val", - "dataset_size": 400, + "dataset_size": 200, "device": device, "nb": nb, "nf": nf, @@ -56,6 +57,7 @@ def get_cfg_pose_to_fourier(device, nb, nf): "trainers": { "r9_l2": {**cfg, **{"preprocess_input": "${u:flatten}", "model": "${model9}"}}, "r6_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_gramschmidt_f}", "model": "${model6}"}}, + "quat_aug_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_aug}", "model": "${model4}"}}, "quat_c_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_canonical}", "model": "${model4}"}}, "quat_rf_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_quaternion_rand_flip}", "model": "${model4}"}}, "euler_l2": {**cfg, **{"preprocess_input": "${u:rotmat_to_euler}", "model": "${model3}"}}, diff --git a/hitchhiking_rotations/datasets/.cube_data.swp b/hitchhiking_rotations/datasets/.cube_data.swp new file mode 100644 index 0000000..514522b Binary files /dev/null and b/hitchhiking_rotations/datasets/.cube_data.swp differ diff --git a/hitchhiking_rotations/datasets/fourier_dataset.py b/hitchhiking_rotations/datasets/fourier_dataset.py index 89a9898..a2790e7 100644 --- a/hitchhiking_rotations/datasets/fourier_dataset.py +++ b/hitchhiking_rotations/datasets/fourier_dataset.py @@ -7,6 +7,11 @@ from torch.utils.data import Dataset import matplotlib.pyplot as plt import roma +import jax +import jax.numpy as jnp +import equinox as eqx + +jax.config.update("jax_default_device", jax.devices("cpu")[0]) from hitchhiking_rotations import HITCHHIKING_ROOT_DIR from hitchhiking_rotations.utils import save_pickle, load_pickle @@ -42,25 +47,31 @@ def __getitem__(self, idx): return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx] -class random_fourier_function: - def __init__(self, n_basis, seed, A0=0.0, L=1.0): - np.random.seed(seed) - self.L = L - self.n_basis = n_basis - self.A0 = A0 - self.A = np.random.normal(size=n_basis) - self.B = np.random.normal(size=n_basis) - self.matrix = np.random.normal(size=(1, 9)) +def random_fourier_function(x, nb, seed): + key = jax.random.PRNGKey(seed) + key1, key2 = jax.random.split(key, 2) + A = jax.random.normal(key=key1, shape=(nb,)) + B = jax.random.normal(key=key2, shape=(nb,)) + + model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42 + seed)) + + fFs = 0.0 + input = model(x) + for k in range(len(A)): + fFs += A[k] * jnp.cos((k + 1) * jnp.pi * input) + B[k] * jnp.sin((k + 1) * jnp.pi * input) + return fFs + + +def input_to_fourier(x, seed): + model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42 + seed)) + return model(x) - def __call__(self, x): - fFs = self.A0 / 2 - for k in range(len(self.A)): - fFs = ( - fFs - + self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) - + self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) - ) - return fFs + +def batch_normalize(arr): + mean = np.mean(arr, axis=0, keepdims=True) + std = np.std(arr, axis=0, keepdims=True) + std[std == 0] = 1 + return (arr - mean) / std def create_data(N_points, nb, seed): @@ -69,7 +80,7 @@ def create_data(N_points, nb, seed): Args: N_points: Number of random rotations to generate nb: Number of fourier basis that form the target function - seed: Used to randomly initialize fourier function coefficients + seed: Used to randomly initialize fourier function Returns: rots: Random rotations features: Target function evaluated at rots @@ -77,12 +88,13 @@ def create_data(N_points, nb, seed): np.random.seed(seed) rots = Rotation.random(N_points) inputs = rots.as_matrix().reshape(N_points, -1) - four_func = random_fourier_function(nb, seed) - features = np.apply_along_axis(four_func, 1, inputs) + features = np.array(jax.vmap(random_fourier_function, in_axes=[0, None, None])(inputs, nb, seed).reshape(-1, 1)) + features = batch_normalize(features) return rots.as_quat().astype(np.float32), features.astype(np.float32) def plot_fourier_data(rotations, features): + """Plot distribution of rotations and features.""" import pandas as pd import seaborn as sns @@ -99,5 +111,32 @@ def plot_fourier_data(rotations, features): plt.show() +def plot_fourier_func(nb, seed): + """Plot the target function.""" + rots = Rotation.random(400) + inputs = rots.as_matrix().reshape(400, -1) + four_in = np.array(jax.vmap(input_to_fourier, [0, None])(inputs, seed)) + features = np.array(jax.vmap(random_fourier_function, [0, None, None])(inputs, nb, seed)) + features2 = batch_normalize(features) + sorted_indices = np.argsort(four_in, axis=0) + + plt.figure() + plt.plot(four_in[sorted_indices].flatten(), features[sorted_indices].flatten(), linestyle="-", marker=None) + plt.plot( + four_in[sorted_indices].flatten(), features2[sorted_indices].flatten(), linestyle="-", color="red", marker=None + ) + plt.title(f"nb: {nb}, seed: {seed}") + plt.show() + + if __name__ == "__main__": - create_data(N_points=100, nb=2, seed=5) + # Analyze created data + for b in range(1, 6): + for s in range(0, 1): + # rots, features = create_data(N_points=100, nb=b, seed=s) + # data_stats(rots, features) + # plot_fourier_data(rots, features) + print("MLP PyTree used to create Fourier function inputs:") + model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42)) + eqx.tree_pprint(model) + plot_fourier_func(b, s) diff --git a/hitchhiking_rotations/models/__init__.py b/hitchhiking_rotations/models/__init__.py index 0d0966b..897a366 100644 --- a/hitchhiking_rotations/models/__init__.py +++ b/hitchhiking_rotations/models/__init__.py @@ -3,4 +3,4 @@ # All rights reserved. Licensed under the MIT license. # See LICENSE file in the project root for details. # -from .models import MLP, CNN, MLPNetPCD +from .models import MLP, MLP2, CNN, MLPNetPCD diff --git a/hitchhiking_rotations/models/models.py b/hitchhiking_rotations/models/models.py index c717b8d..9fc901d 100644 --- a/hitchhiking_rotations/models/models.py +++ b/hitchhiking_rotations/models/models.py @@ -22,6 +22,23 @@ def forward(self, x): return self.model(x) +class MLP2(nn.Module): + def __init__(self, input_dim, output_dim): + super(MLP2, self).__init__() + self.model = nn.Sequential( + nn.Linear(input_dim, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, output_dim), + ) + + def forward(self, x): + return self.model(x) + + class CNN(nn.Module): def __init__(self, input_dim, width, height): super(CNN, self).__init__() diff --git a/hitchhiking_rotations/utils/conversions.py b/hitchhiking_rotations/utils/conversions.py index 7289b1b..1331e62 100644 --- a/hitchhiking_rotations/utils/conversions.py +++ b/hitchhiking_rotations/utils/conversions.py @@ -3,27 +3,28 @@ # All rights reserved. Licensed under the MIT license. # See LICENSE file in the project root for details. # -from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles +from hitchhiking_rotations.utils.euler_helper import euler_angles_to_matrix, matrix_to_euler_angles import roma import torch +from math import pi -def euler_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def euler_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY") -def quaternion_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def quaternion_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: # without normalization # normalize first x = inp.reshape(-1, 4) return roma.unitquat_to_rotmat(x / x.norm(dim=1, keepdim=True)) -def gramschmidt_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def gramschmidt_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return roma.special_gramschmidt(inp.reshape(-1, 3, 2)) -def symmetric_orthogonalization(x): +def symmetric_orthogonalization(x, **kwargs): """Maps 9D input vectors onto SO(3) via symmetric orthogonalization. x: should have size [batch_size, 9] @@ -40,65 +41,80 @@ def symmetric_orthogonalization(x): return r -def procrustes_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def procrustes_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return symmetric_orthogonalization(inp) return roma.special_procrustes(inp.reshape(-1, 3, 3)) -def rotvec_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def rotvec_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotvec_to_rotmat(inp.reshape(-1, 3)) # rotmat to x / maybe here reshape is missing -def rotmat_to_euler(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_euler(base: torch.Tensor, **kwargs) -> torch.Tensor: return matrix_to_euler_angles(base, convention="XZY") -def rotmat_to_quaternion(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion(base: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotmat_to_unitquat(base) -def rotmat_to_quaternion_rand_flip(base: torch.Tensor) -> torch.Tensor: - # we could duplicate the data and flip the quaternions on both sides - # def quat_aug_dataset(quats: np.ndarray, ixs): - # # quats: (N, M, .., 4) - # # return augmented inputs and quats - # return (np.concatenate((quats, -quats), axis=0), *np.concatenate((ixs, ixs), axis=0)) - +def rotmat_to_quaternion_rand_flip(base: torch.Tensor, **kwargs) -> torch.Tensor: rep = roma.rotmat_to_unitquat(base) rand_flipping = torch.rand(base.shape[0]) > 0.5 rep[rand_flipping] *= -1 return rep -def rotmat_to_quaternion_canonical(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor: rep = roma.rotmat_to_unitquat(base) rep[rep[:, 3] < 0] *= -1 return rep -def rotmat_to_gramschmidt(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion_aug(base: torch.Tensor, mode: str) -> torch.Tensor: + """Performs memory-efficient quaternion augmentation by randomly + selecting half of the quaternions in the batch with scalar part + smaller than 0.1 and then multiplies them by -1. + """ + rep = rotmat_to_quaternion_canonical(base) + + if mode == "train": + rep[(torch.rand(rep.size(0), device=rep.device) < 0.5) * (rep[:, 3] < 0.1)] *= -1 + + return rep + + +def rotmat_to_gramschmidt(base: torch.Tensor, **kwargs) -> torch.Tensor: return base[:, :, :2] -def rotmat_to_gramschmidt_f(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_gramschmidt_f(base: torch.Tensor, **kwargs) -> torch.Tensor: return base[:, :, :2].reshape(-1, 6) -def rotmat_to_procrustes(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_procrustes(base: torch.Tensor, **kwargs) -> torch.Tensor: return base -def rotmat_to_rotvec(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_rotvec(base: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotmat_to_rotvec(base) +def rotmat_to_rotvec_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor: + """WARNING: THIS FUNCTION HAS NOT BEEN TESTED""" + rep = roma.rotmat_to_rotvec(base) + rep[rep[:, 2] < 0] = (1.0 - 2.0 * pi / rep[rep[:, 2] < 0].norm(dim=1, keepdim=True)) * rep[rep[:, 2] < 0] + return rep + + def test_all(): from scipy.spatial.transform import Rotation from torch import from_numpy as tr import numpy as np + from torch import from_numpy as tr rs = Rotation.random(1000) euler = rs.as_euler("XZY", degrees=False) diff --git a/hitchhiking_rotations/utils/helper.py b/hitchhiking_rotations/utils/helper.py index 7fb8bc1..34491aa 100644 --- a/hitchhiking_rotations/utils/helper.py +++ b/hitchhiking_rotations/utils/helper.py @@ -3,15 +3,15 @@ # All rights reserved. Licensed under the MIT license. # See LICENSE file in the project root for details. # -def passthrough(*x): +def passthrough(*x, **kwargs): if len(x) == 1: return x[0] return x -def flatten(x): +def flatten(x, **kwargs): return x.reshape(x.shape[0], -1) -def n_3x3(x): +def n_3x3(x, **kwargs): return x.reshape(-1, 3, 3) diff --git a/hitchhiking_rotations/utils/notation.py b/hitchhiking_rotations/utils/notation.py index 044b875..98e4189 100644 --- a/hitchhiking_rotations/utils/notation.py +++ b/hitchhiking_rotations/utils/notation.py @@ -7,15 +7,16 @@ class RotRep(Enum): - GSO = "$\mathbb{R}^6$+GSO" - SVD = "$\mathbb{R}^9$+SVD" - QUAT_C = "Quat$^+$" - QUAT = "Quat" - QUAT_RF = "Quat+RF" - EULER = "Euler" - EXP = "Exp" - ROTMAT = "$\mathbb{R}^9$" - RSIX = "$\mathbb{R}^6$" + GSO = r"$\mathbb{R}^6$+GSO" + SVD = r"$\mathbb{R}^9$+SVD" + QUAT_C = r"Quat$^+$" + QUAT = r"Quat" + QUAT_RF = r"Quat$^{\mathrm{RF}}$" + QUAT_AUG = r"Quat$^{\mathrm{a}{+}}$" + EULER = r"Euler" + EXP = r"Exp" + ROTMAT = r"$\mathbb{R}^9$" + RSIX = r"$\mathbb{R}^6$" def __str__(self): return "%s" % self.value diff --git a/hitchhiking_rotations/utils/trainer.py b/hitchhiking_rotations/utils/trainer.py index f44e1b1..1b2c54c 100644 --- a/hitchhiking_rotations/utils/trainer.py +++ b/hitchhiking_rotations/utils/trainer.py @@ -39,6 +39,7 @@ def __init__( loss, model, lr, + patience, optimizer, logger, verbose, @@ -68,7 +69,7 @@ def __init__( self.nr_training_steps = 0 self.nr_test_steps = 0 - self.early_stopper = EarlyStopper(model=self.model, patience=10, min_delta=0) + self.early_stopper = EarlyStopper(model=self.model, patience=patience, min_delta=0) def train_batch(self, x, target, epoch): self.model.train() @@ -77,7 +78,7 @@ def train_batch(self, x, target, epoch): with torch.no_grad(): pp_target = self.preprocess_target(target) - x = self.preprocess_input(x) + x = self.preprocess_input(x, mode="train") pred = self.model(x) pred_loss = self.postprocess_pred_loss(pred) @@ -95,7 +96,7 @@ def train_batch(self, x, target, epoch): @torch.no_grad() def test_batch(self, x, target, epoch, mode): self.model.eval() - x = self.preprocess_input(x) + x = self.preprocess_input(x, mode="test") pred = self.model(x) pred_loss = self.postprocess_pred_loss(pred) pp_target = self.preprocess_target(target) diff --git a/visu/figure_12a.py b/visu/figure_12a.py index 91cda28..f93cbab 100644 --- a/visu/figure_12a.py +++ b/visu/figure_12a.py @@ -43,22 +43,21 @@ df = pd.DataFrame.from_dict(df_res) mapping = { - "r9_svd": str(RotRep.SVD) + "-Chordal", - "r6_gso": str(RotRep.GSO) + "-Chordal", - "quat_c": str(RotRep.QUAT_C) + "-Chordal", - "rotvec": str(RotRep.EXP) + "-Chordal", - "euler": str(RotRep.EULER) + "-Chordal", + "r9_svd": RotRep.SVD, + "r6_gso": RotRep.GSO, + "quat_c": RotRep.QUAT_C, + # "quat_rf": RotRep.QUAT_RF, + "rotvec": RotRep.EXP, + "euler": RotRep.EULER, } -for k, v in mapping.items(): - df["method"][df["method"] == k + "_" + training_metric] = v - +df["method"] = df["method"].replace({k + "_" + training_metric: v for k, v in mapping.items()}) df["method"] = pd.Categorical(df["method"], categories=[v for v in mapping.values()], ordered=True) plt.style.use(os.path.join(HITCHHIKING_ROOT_DIR, "assets", "prettyplots.mplstyle")) sns.set_style("whitegrid") plt.rcParams.update({"font.size": 11}) -plt.figure(figsize=(7, 2.5)) +plt.figure(figsize=(5, 2.5)) plt.subplot(1, 1, 1) @@ -73,10 +72,18 @@ fliersize=2.5, showfliers=True, ) -plt.xlabel("Error - Geodesic") + +if selected_metric == "geodesic_distance": + plt.xlabel("Geodesic distance") + plt.ylabel("") -plt.tight_layout() +# plt.xscale("log") +print("WARNING: Tick labels are hardcoded!") +plt.xticks([0.3, 0.4, 0.5, 0.6], ["0.3", "0.4", "0.5", "0.6"]) + +plt.tight_layout() out_p = os.path.join(HITCHHIKING_ROOT_DIR, "results", exp, "figure_12a.pdf") + plt.savefig(out_p) plt.show() diff --git a/visu/figure_12b.py b/visu/figure_12b.py index 10e1577..6c3912f 100644 --- a/visu/figure_12b.py +++ b/visu/figure_12b.py @@ -7,7 +7,6 @@ import pandas as pd from hitchhiking_rotations.utils import RotRep - files = [str(s) for s in Path(os.path.join(HITCHHIKING_ROOT_DIR, "results", "pose_to_cube_image")).rglob("*result.npy")] results = [np.load(file, allow_pickle=True) for file in files] @@ -47,30 +46,33 @@ mapping = { "r9": RotRep.ROTMAT, "r6": RotRep.RSIX, + "quat_aug": RotRep.QUAT_AUG, "quat_c": RotRep.QUAT_C, - "quat_rf": str(RotRep.QUAT) + "_rf", + "quat_rf": RotRep.QUAT_RF, "rotvec": RotRep.EXP, "euler": RotRep.EULER, } training_metric = "l2" - for k, v in mapping.items(): - df["method"][df["method"] == k + "_" + training_metric] = v - + df["method"] = df["method"].replace({k + "_" + training_metric: v for k, v in mapping.items()}) df["method"] = pd.Categorical(df["method"], categories=[v for v in mapping.values()], ordered=True) plt.style.use(os.path.join(HITCHHIKING_ROOT_DIR, "assets", "prettyplots.mplstyle")) sns.set_style("whitegrid") plt.rcParams.update({"font.size": 11}) -plt.figure(figsize=(7, 2.5)) +plt.figure(figsize=(5, 2.5)) plt.subplot(1, 1, 1) - sns.boxplot(data=df, x="score", y="method", palette="Greens", orient="h", width=0.5, linewidth=1.5, fliersize=2.5) -plt.xlabel("Error - MSE") + +plt.xlabel("MSE") plt.ylabel("") -plt.tight_layout() +plt.xscale("log") + +print("WARNING: Tick labels are hardcoded!") +plt.xticks([0.0005, 0.001, 0.002, 0.004], [r"$5\cdot10^{-4}$", r"$10^{-3}$", r"$2\cdot10^{-3}$", r"$4\cdot10^{-3}$"]) +plt.tight_layout() out_p = os.path.join(HITCHHIKING_ROOT_DIR, "results", "pose_to_cube_image", "figure_12b.pdf") plt.savefig(out_p) plt.show() diff --git a/visu/figure_14.py b/visu/figure_14.py index a64eff3..0782f6e 100644 --- a/visu/figure_14.py +++ b/visu/figure_14.py @@ -57,12 +57,13 @@ if rename_and_filter: mapping = { - "r9": RotRep.SVD, - "r6": RotRep.GSO, - "quat_c": RotRep.QUAT_C, - "quat_rf": str(RotRep.QUAT) + "_rf", - "rotvec": RotRep.EXP, "euler": RotRep.EULER, + "rotvec": RotRep.EXP, + "quat_rf": RotRep.QUAT_RF, + "quat_c": RotRep.QUAT_C, + "quat_aug": RotRep.QUAT_AUG, + "r6": RotRep.RSIX, + "r9": RotRep.ROTMAT, } training_metric = "l2" @@ -73,7 +74,7 @@ sns.set_style("whitegrid") plt.rcParams.update({"font.size": 11}) -plt.figure(figsize=(5.5, 1.0)) +plt.figure(figsize=(5.5, 1)) g = sns.catplot( data=df, x="basis", @@ -86,13 +87,18 @@ aspect=2.0, ) -sns.move_legend(g, "upper left", bbox_to_anchor=(0.11, 0.98), ncol=3, title="Network input") # len(names) +# g.map(sns.stripplot, "basis", "score", "method", dodge=True, alpha=0.6) + +# sns.move_legend(g, "upper left", bbox_to_anchor=(0.11, 0.98), ncol=2) #, title="Network input") +sns.move_legend( + g, "upper left", bbox_to_anchor=(0.0555, 0.99), ncol=7, handletextpad=0.2, columnspacing=0.55, title=None +) for i in range(nb_max - 1): plt.axvline(0.5 + i, color="lightgrey", dashes=(2, 2)) -plt.xlabel(f"Error - {selected_metric}") -plt.ylabel("") +g.set(xlabel=r"Number of fourier basis functions $n_b$", ylabel="MSE") + plt.yscale("log") plt.tight_layout()