diff --git a/wassersteinwormhole/DefaultConfig.py b/wassersteinwormhole/DefaultConfig.py index 1a9a3fe..519cd32 100644 --- a/wassersteinwormhole/DefaultConfig.py +++ b/wassersteinwormhole/DefaultConfig.py @@ -2,7 +2,7 @@ from flax import linen as nn import jax.numpy as jnp -from typing import Callable, Any, Optional +from typing import Callable, Any @struct.dataclass class DefaultConfig: diff --git a/wassersteinwormhole/Wormhole.py b/wassersteinwormhole/Wormhole.py index b94b8a6..54c6a82 100644 --- a/wassersteinwormhole/Wormhole.py +++ b/wassersteinwormhole/Wormhole.py @@ -1,102 +1,136 @@ - -import optax -from flax import linen as nn -from flax import struct -from flax.training import train_state +from functools import partial import jax import jax.numpy as jnp import jax.scipy as jsp -from jax import random, grad, jit, vmap -from functools import partial -import scipy.stats import numpy as np +import optax +import scipy.stats +from flax import linen as nn +from jax import jit, random from tqdm import trange -from wassersteinwormhole._utils_Transformer import * import wassersteinwormhole.utils_OT as utils_OT - +from wassersteinwormhole._utils_Transformer import Metrics, TrainState, Transformer from wassersteinwormhole.DefaultConfig import DefaultConfig - - + + def MaxMinScale(arr): - """ :meta private: """ - - min_arr = arr.min(axis = 0) - max_arr = arr.max(axis = 0) - - arr = 2*(arr - arr.min(axis = 0, keepdims = True))/(arr.max(axis = 0, keepdims = True) - arr.min(axis = 0, keepdims = True))-1 - return(arr) - -def pad_pointclouds(point_clouds, weights, max_shape = -1): + + arr = ( + 2 + * (arr - arr.min(axis=0, keepdims=True)) + / (arr.max(axis=0, keepdims=True) - arr.min(axis=0, keepdims=True)) + - 1 + ) + return arr + + +def pad_pointclouds(point_clouds, weights, max_shape=-1): """ :meta private: """ - - if(max_shape == -1): - max_shape = np.max([pc.shape[0] for pc in point_clouds])+1 + + if max_shape == -1: + max_shape = np.max([pc.shape[0] for pc in point_clouds]) + 1 else: max_shape = max_shape + 1 - weights_pad = np.asarray([np.concatenate((weight, np.zeros(max_shape - pc.shape[0])), axis = 0) for pc, weight in zip(point_clouds, weights)]) - point_clouds_pad = np.asarray([np.concatenate([pc, np.zeros([max_shape - pc.shape[0], pc.shape[-1]])], axis = 0) for pc in point_clouds]) - - weights_pad = weights_pad/weights_pad.sum(axis = 1, keepdims = True) + weights_pad = np.asarray( + [ + np.concatenate((weight, np.zeros(max_shape - pc.shape[0])), axis=0) + for pc, weight in zip(point_clouds, weights) + ] + ) + point_clouds_pad = np.asarray( + [ + np.concatenate( + [pc, np.zeros([max_shape - pc.shape[0], pc.shape[-1]])], axis=0 + ) + for pc in point_clouds + ] + ) + + weights_pad = weights_pad / weights_pad.sum(axis=1, keepdims=True) + + return ( + point_clouds_pad[:, :-1].astype("float32"), + weights_pad[:, :-1].astype("float32"), + ) + + +class Wormhole: + """ + Initializes Wormhole model and processes input point clouds - return(point_clouds_pad[:, :-1].astype('float32'), weights_pad[:, :-1].astype('float32')) -class Wormhole(): - - """ - Initializes Wormhole model and processes input point clouds - - :param point_clouds: (list of np.array) list of train-set point clouds to train Wormhole on :param weights: (list of np.array) list of per point weight for each train-set point cloud (default None, indicating uniform weights) :param point_clouds_test: (list of np.array) list of test-set point clouds (default None) :param weights_test: (list of np.array) list of per point weight for each test-set point cloud (default None, indicating uniform weights) - :param config: (flax struct.dataclass) object with parameters for Wormhole such as OT metric choice, emedding dimention, etc. See docs for 'DefaultConfig.py' and tutorial details. - + :param config: (flax struct.dataclass) object with parameters for Wormhole such as OT metric choice, emedding dimention, etc. See docs for 'DefaultConfig.py' and tutorial details. + :return: initialized Wormhole model - """ - - def __init__(self, point_clouds, weights = None, point_clouds_test = None, weights_test = None, config = DefaultConfig): - + """ + + def __init__( + self, + point_clouds, + weights=None, + point_clouds_test=None, + weights_test=None, + config=DefaultConfig, + ): + self.config = config self.point_clouds = point_clouds - - if(weights is None): - self.weights = [np.ones(pc.shape[0])/pc.shape[0] for pc in self.point_clouds] + + if weights is None: + self.weights = [ + np.ones(pc.shape[0]) / pc.shape[0] for pc in self.point_clouds + ] else: self.weights = weights - - if(point_clouds_test is None): - self.point_clouds, self.weights = pad_pointclouds(self.point_clouds, self.weights) + + if point_clouds_test is None: + self.point_clouds, self.weights = pad_pointclouds( + self.point_clouds, self.weights + ) else: self.point_clouds_test = point_clouds_test - - if(weights_test is None): - self.weights_test = [np.ones(pc.shape[0])/pc.shape[0] for pc in self.point_clouds_test] + + if weights_test is None: + self.weights_test = [ + np.ones(pc.shape[0]) / pc.shape[0] for pc in self.point_clouds_test + ] else: self.weights_test = weights_test - - - total_point_clouds, total_weights = pad_pointclouds(list(self.point_clouds) + list(self.point_clouds_test), list(self.weights) + list(self.weights_test)) - self.point_clouds, self.weights = total_point_clouds[:len(list(self.point_clouds))], total_weights[:len(list(self.point_clouds))] - self.point_clouds_test, self.weights_test = total_point_clouds[len(list(self.point_clouds)):], total_weights[len(list(self.point_clouds)):] - - self.scale_weights = np.exp(-jsp.special.xlogy(self.weights, self.weights).sum(axis = 1).mean()) - self.out_seq_len = int(jnp.exp(-jsp.special.xlogy(self.weights, self.weights).sum(axis = 1).mean())) + total_point_clouds, total_weights = pad_pointclouds( + list(self.point_clouds) + list(self.point_clouds_test), + list(self.weights) + list(self.weights_test), + ) + self.point_clouds, self.weights = ( + total_point_clouds[: len(list(self.point_clouds))], + total_weights[: len(list(self.point_clouds))], + ) + self.point_clouds_test, self.weights_test = ( + total_point_clouds[len(list(self.point_clouds)) :], + total_weights[len(list(self.point_clouds)) :], + ) + + self.scale_weights = np.exp( + -jsp.special.xlogy(self.weights, self.weights).sum(axis=1) + ).mean() + self.out_seq_len = int( + jnp.exp(-jsp.special.xlogy(self.weights, self.weights).sum(axis=1)).mean() + ) self.inp_dim = self.point_clouds.shape[-1] - - - self.eps_enc = config.eps_enc self.eps_dec = config.eps_dec @@ -104,80 +138,109 @@ def __init__(self, point_clouds, weights = None, point_clouds_test = None, weigh self.lse_dec = config.lse_dec self.coeff_dec = config.coeff_dec - + self.dist_func_enc = config.dist_func_enc self.dist_func_dec = config.dist_func_dec - - self.jit_dist_enc = jax.jit(jax.vmap(getattr(utils_OT, self.dist_func_enc), (0, 0, None, None), 0), static_argnums=[2,3]) - self.jit_dist_dec = jax.jit(jax.vmap(getattr(utils_OT, self.dist_func_dec), (0, 0, None, None), 0), static_argnums=[2,3]) - - if(self.coeff_dec < 0): - self.jit_dist_dec = jax.jit(jax.vmap(utils_OT.Zeros, (0, 0, None, None), 0), static_argnums=[2,3]) - self.coeff_dec = 0.0 + + self.jit_dist_enc = jax.jit( + jax.vmap(getattr(utils_OT, self.dist_func_enc), (0, 0, None, None), 0), + static_argnums=[2, 3], + ) + self.jit_dist_dec = jax.jit( + jax.vmap(getattr(utils_OT, self.dist_func_dec), (0, 0, None, None), 0), + static_argnums=[2, 3], + ) + + if self.coeff_dec < 0: + self.jit_dist_dec = jax.jit( + jax.vmap(utils_OT.Zeros, (0, 0, None, None), 0), static_argnums=[2, 3] + ) + self.coeff_dec = 0.0 self.scale = config.scale self.factor = config.factor self.point_clouds = self.scale_func(self.point_clouds) * self.factor - if(point_clouds_test is not None): - self.point_clouds_test = self.scale_func(self.point_clouds_test)*self.factor - - - self.pc_max_val = np.max(self.point_clouds[self.weights > 0]) #* (1 + 1 * np.isin(self.dist_func_dec, ['GS', 'GW'])) - self.pc_min_val = np.min(self.point_clouds[self.weights > 0]) #* (1 + 1 * np.isin(self.dist_func_dec, ['GS', 'GW'])) - self.scale_out = True #not np.isin(self.dist_func_dec, ['GS', 'GW']) - - self.model = Transformer(self.config, out_seq_len = self.out_seq_len, inp_dim = self.inp_dim, - scale_weights = self.scale_weights, scale_out = self.scale_out, min_val = self.pc_min_val, max_val = self.pc_max_val) - + if point_clouds_test is not None: + self.point_clouds_test = ( + self.scale_func(self.point_clouds_test) * self.factor + ) + + self.pc_max_val = np.max( + self.point_clouds[self.weights > 0] + ) # * (1 + 1 * np.isin(self.dist_func_dec, ['GS', 'GW'])) + self.pc_min_val = np.min( + self.point_clouds[self.weights > 0] + ) # * (1 + 1 * np.isin(self.dist_func_dec, ['GS', 'GW'])) + self.scale_out = True # not np.isin(self.dist_func_dec, ['GS', 'GW']) + + self.model = Transformer( + self.config, + out_seq_len=self.out_seq_len, + inp_dim=self.inp_dim, + scale_weights=self.scale_weights, + scale_out=self.scale_out, + min_val=self.pc_min_val, + max_val=self.pc_max_val, + ) def scale_func(self, point_clouds): - """ :meta private: """ - - if(self.scale == 'max_dist_total'): - if(not hasattr(self, 'max_scale_num')): + + if self.scale == "max_dist_total": + if not hasattr(self, "max_scale_num"): max_dist = 0 for _ in range(10): - i,j = np.random.choice(np.arange(len(self.point_clouds)), 2,replace = False) - if(self.dist_func_enc == 'GW' or self.dist_func_enc == 'GS'): - max_ij = np.max(scipy.spatial.distance.cdist(self.point_clouds[i], self.point_clouds[i])) + i, j = np.random.choice( + np.arange(len(self.point_clouds)), 2, replace=False + ) + if self.dist_func_enc == "GW" or self.dist_func_enc == "GS": + max_ij = np.max( + scipy.spatial.distance.cdist( + self.point_clouds[i], self.point_clouds[i] + ) + ) else: - max_ij = np.max(scipy.spatial.distance.cdist(self.point_clouds[i], self.point_clouds[j])) + max_ij = np.max( + scipy.spatial.distance.cdist( + self.point_clouds[i], self.point_clouds[j] + ) + ) max_dist = np.maximum(max_ij, max_dist) self.max_scale_num = max_dist else: - print("Using Calculated Max Dist Scaling Values") - return(point_clouds/self.max_scale_num) - if(self.scale == 'max_dist_each'): - print("Using Per Sample Max Dist") - pc_scale = np.asarray([pc/np.max(scipy.spatial.distance.pdist(pc)) for pc in point_clouds]) - return(pc_scale) - if(self.scale == 'min_max_each'): - print("Scaling Per Sample") - max_val = point_clouds.max(axis = 1, keepdims = True) - min_val = point_clouds.min(axis = 1, keepdims = True) - return(2 * (point_clouds - min_val)/(max_val - min_val) - 1) - elif(self.scale == 'min_max_total'): - if(not hasattr(self, 'max_val')): - self.max_val = self.point_clouds.max(axis = ((0,1)), keepdims = True) - self.min_val = self.point_clouds.min(axis = ((0,1)), keepdims = True) + print("Using Calculated Max Dist Scaling Values") + return point_clouds / self.max_scale_num + if self.scale == "max_dist_each": + print("Using Per Sample Max Dist") + pc_scale = np.asarray( + [pc / np.max(scipy.spatial.distance.pdist(pc)) for pc in point_clouds] + ) + return pc_scale + if self.scale == "min_max_each": + print("Scaling Per Sample") + max_val = point_clouds.max(axis=1, keepdims=True) + min_val = point_clouds.min(axis=1, keepdims=True) + return 2 * (point_clouds - min_val) / (max_val - min_val) - 1 + elif self.scale == "min_max_total": + if not hasattr(self, "max_val"): + self.max_val = self.point_clouds.max(axis=((0, 1)), keepdims=True) + self.min_val = self.point_clouds.min(axis=((0, 1)), keepdims=True) else: - print("Using Calculated Min Max Scaling Values") - return(2 * (point_clouds - self.min_val)/(self.max_val - self.min_val) - 1) - elif(self.scale == 'min_max_total_all_axis'): - if(not hasattr(self, 'max_val')): - self.max_val = self.point_clouds.max(keepdims = True) - self.min_val = self.point_clouds.min(keepdims = True) + print("Using Calculated Min Max Scaling Values") + return 2 * (point_clouds - self.min_val) / (self.max_val - self.min_val) - 1 + elif self.scale == "min_max_total_all_axis": + if not hasattr(self, "max_val"): + self.max_val = self.point_clouds.max(keepdims=True) + self.min_val = self.point_clouds.min(keepdims=True) else: - print("Using Calculated Min Max Scaling Values") - return(2 * (point_clouds - self.min_val)/(self.max_val - self.min_val) - 1) + print("Using Calculated Min Max Scaling Values") + return 2 * (point_clouds - self.min_val) / (self.max_val - self.min_val) - 1 else: - return(point_clouds) - - def encode(self, pc, weights, max_batch = 256): - + return point_clouds + + def encode(self, pc, weights, max_batch=256): """ Encode point clouds with trained Wormhole model @@ -187,21 +250,29 @@ def encode(self, pc, weights, max_batch = 256): :param max_batch: (int) maximum size of batch during inference calls to Wormhole (default 256) :return enc: per point cloud embeddings - """ - - if(pc.shape[0] < max_batch): - enc = self.model.bind({'params': self.params}).Encoder(pc, weights, deterministic = True) - else: # For when the GPU can't pass all point-clouds at once - num_split = int(pc.shape[0]/max_batch)+1 + """ + + if pc.shape[0] < max_batch: + enc = self.model.bind({"params": self.params}).Encoder( + pc, weights, deterministic=True + ) + else: # For when the GPU can't pass all point-clouds at once + num_split = int(pc.shape[0] / max_batch) + 1 pc_split = np.array_split(pc, num_split) mask_split = np.array_split(weights, num_split) - - enc = np.concatenate([self.model.bind({'params': self.params}).Encoder(pc_split[split_ind], mask_split[split_ind], deterministic = True) for - split_ind in range(num_split)], axis = 0) + + enc = np.concatenate( + [ + self.model.bind({"params": self.params}).Encoder( + pc_split[split_ind], mask_split[split_ind], deterministic=True + ) + for split_ind in range(num_split) + ], + axis=0, + ) return enc - - def decode(self, enc, max_batch = 256): - + + def decode(self, enc, max_batch=256): """ Decode embedding back into point clouds using Wormhole decoder @@ -210,111 +281,169 @@ def decode(self, enc, max_batch = 256): :param max_batch: (int) maximum size of batch during inference calls to Wormhole (default 256) :return dec: decoded point clouds from embeddings - """ - - if(enc.shape[0]