From b135d4f4f2f392f53063c3da62267de78b2ab9aa Mon Sep 17 00:00:00 2001
From: YannisZa
Date: Mon, 22 Apr 2024 15:46:38 +0100
Subject: [PATCH] added option to evaluate metrics at specific cells

---
 gensit/utils/math_utils.py | 98 ++++++++++++++++++++++----------------
 1 file changed, 57 insertions(+), 41 deletions(-)

diff --git a/gensit/utils/math_utils.py b/gensit/utils/math_utils.py
index 6d1b3b4..b25f268 100644
--- a/gensit/utils/math_utils.py
+++ b/gensit/utils/math_utils.py
@@ -212,13 +212,13 @@ def srmse(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
     ground_truth = ground_truth.astype('float32')
     prediction,ground_truth = xr.broadcast(prediction,ground_truth)
     prediction,ground_truth = xr.align(prediction,ground_truth, join='exact')
-    test_cells = kwargs.get('test_cells',None)
+    cells = kwargs.get('cells',None)
 
-    if test_cells is not None:
+    if cells is not None:
         # Mask all non test cells
         mask = ground_truth.copy(deep=True)
         mask[:] = False
-        for cell in test_cells:
+        for cell in cells:
             mask[cell[0],cell[1]] = True
         # Apply mask
         prediction = prediction.where(mask)
@@ -249,7 +249,7 @@ def ssi(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
 
     """
 
-    test_cells = kwargs.get('test_cells',None)
+    cells = kwargs.get('cells',None)
     # Compute denominator
     denominator = (ground_truth + prediction)
     denominator = xr.where(denominator <= 0, 1., denominator)
@@ -257,11 +257,11 @@ def ssi(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
     numerator = 2*np.minimum(ground_truth,prediction)
     ratio = numerator / denominator
 
-    if test_cells is not None:
+    if cells is not None:
         # Mask all non test cells
         mask = ratio.copy(deep=True)
         mask[:] = False
-        for cell in test_cells:
+        for cell in cells:
             mask[cell[0],cell[1]] = True
         # Apply mask
         ratio = ratio.where(mask)
@@ -270,45 +270,30 @@ def ssi(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
     ssi = ratio.mean(dim=['origin','destination'],skipna=True)
     return ssi
 
-def von_neumann_entropy(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
-    # Convert matrix to square
-    matrix = (prediction@prediction.T).astype('float32')
-    # Add jitter
-    matrix += kwargs['epsilon_threshold'] * torch.eye(matrix.shape[0],dtype='float32')
-    # Find eigenvalues
-    eigenval = torch.real(torch.linalg.eigvals(matrix))
-    # Get all non-zero eigenvalues
-    eigenval = eigenval[~torch.isclose(eigenval,0,atol = 1e-08)]
-    # Compute entropy
-    res = torch.sum(-eigenval*torch.log(eigenval)).to(dtype = float32)
-
-    return res
-
-def sparsity(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
-    """Computes percentage of zero cells in table
-
-    Parameters
-    ----------
-    prediction : xr.DataArray
-        Description of parameter `table`.
-
-    Returns
-    -------
-    float
-        Description of returned object.
-
-    """
-    N,_,_ = shape(prediction)
-    res = np.count_nonzero(prediction==0)/np.prod(prediction.size)
-    return res
-
+def markov_basis_distance(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
+    cells = kwargs.get('cells',None)
+    if cells is not None:
+        # Mask all non test cells
+        mask = ground_truth.copy(deep=True)
+        mask[:] = False
+        for cell in cells:
+            mask[cell[0],cell[1]] = True
+        # Apply mask
+        prediction = prediction.where(mask)
+        ground_truth = ground_truth.where(mask)
+
+    return np.abs(prediction - ground_truth).sum(
+        dim = ['origin','destination'],
+        dtype = 'float64',
+        skipna = True
+    ) / 2
 
 
 def coverage_probability(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
     # Get region mass
     region_mass = kwargs.get('region_mass',0.95)
     # Get test cells
-    test_cells = kwargs.get('test_cells',None)
+    cells = kwargs.get('cells',None)
 
     # High posterior density mass
     alpha = 1-region_mass
@@ -330,11 +315,11 @@ def coverage_probability(prediction:xr.DataArray,ground_truth:xr.DataArray=None,
     # Compute flag for whether ground truth table is covered
     cell_coverage = (ground_truth >= lower_bound_hpdr) & (ground_truth <= upper_bound_hpdr)
 
-    if test_cells is not None:
+    if cells is not None:
         # Mask all non test cells
         mask = cell_coverage.copy(deep=True)
         mask[:] = False
-        for cell in test_cells:
+        for cell in cells:
             mask[cell[0],cell[1]] = True
         # Apply mask
         return cell_coverage.where(mask)
@@ -386,6 +371,37 @@ def calculate_min_interval(x, alpha):
     return hdi_min, hdi_max
+def von_neumann_entropy(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
+    # Convert matrix to square
+    matrix = (prediction@prediction.T).astype('float32')
+    # Add jitter
+    matrix += kwargs['epsilon_threshold'] * torch.eye(matrix.shape[0],dtype='float32')
+    # Find eigenvalues
+    eigenval = torch.real(torch.linalg.eigvals(matrix))
+    # Get all non-zero eigenvalues
+    eigenval = eigenval[~torch.isclose(eigenval,0,atol = 1e-08)]
+    # Compute entropy
+    res = torch.sum(-eigenval*torch.log(eigenval)).to(dtype = float32)
+
+    return res
+
+def sparsity(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
+    """Computes percentage of zero cells in table
+
+    Parameters
+    ----------
+    prediction : xr.DataArray
+        Description of parameter `table`.
+
+    Returns
+    -------
+    float
+        Description of returned object.
+
+    """
+    N,_,_ = shape(prediction)
+    res = np.count_nonzero(prediction==0)/np.prod(prediction.size)
+    return res
 
 def logsumexp(input, dim = None):
     max_val = input.max(dim = dim)
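
Note on usage (not part of the patch): the new `cells` keyword accepted by srmse, ssi, markov_basis_distance and coverage_probability is an iterable of (origin index, destination index) pairs; each metric builds a boolean mask from those pairs and evaluates only on the masked entries. The sketch below is a minimal reproduction of that masking pattern in plain xarray so the behaviour can be checked in isolation; the array shapes, the random values and the test_cells list are illustrative assumptions, not taken from the repository.

    import numpy as np
    import xarray as xr

    # Toy prediction / ground-truth tables over the origin/destination
    # dimensions used throughout gensit/utils/math_utils.py (sizes made up).
    prediction = xr.DataArray(np.random.rand(3, 4), dims=["origin", "destination"])
    ground_truth = xr.DataArray(np.random.rand(3, 4), dims=["origin", "destination"])

    # Cells to evaluate, as (origin index, destination index) pairs --
    # the format the patched metrics expect via the `cells` kwarg.
    test_cells = [(0, 1), (2, 3)]

    # Same masking pattern as in the patch: start from an all-False mask,
    # flag the requested cells, then blank out everything else with where().
    mask = xr.zeros_like(ground_truth, dtype=bool)
    for cell in test_cells:
        mask[cell[0], cell[1]] = True
    prediction = prediction.where(mask)
    ground_truth = ground_truth.where(mask)

    # Reductions with skipna=True then ignore the NaN-ed cells, e.g. a
    # markov_basis_distance-style statistic restricted to test_cells:
    print(np.abs(prediction - ground_truth).sum(
        dim=["origin", "destination"], skipna=True) / 2)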