Skip to content

Commit

Permalink
added option to evaluate metrics at specific cells
Browse files Browse the repository at this point in the history
  • Loading branch information
YannisZa committed Apr 22, 2024
1 parent d1abdf5 commit b135d4f
Showing 1 changed file with 57 additions and 41 deletions.
98 changes: 57 additions & 41 deletions gensit/utils/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ def srmse(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
ground_truth = ground_truth.astype('float32')
prediction,ground_truth = xr.broadcast(prediction,ground_truth)
prediction,ground_truth = xr.align(prediction,ground_truth, join='exact')
test_cells = kwargs.get('test_cells',None)
cells = kwargs.get('cells',None)

if test_cells is not None:
if cells is not None:
# Mask all non test cells
mask = ground_truth.copy(deep=True)
mask[:] = False
for cell in test_cells:
for cell in cells:
mask[cell[0],cell[1]] = True
# Apply mask
prediction = prediction.where(mask)
Expand Down Expand Up @@ -249,19 +249,19 @@ def ssi(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
"""

test_cells = kwargs.get('test_cells',None)
cells = kwargs.get('cells',None)
# Compute denominator
denominator = (ground_truth + prediction)
denominator = xr.where(denominator <= 0, 1., denominator)
# Compute numerator
numerator = 2*np.minimum(ground_truth,prediction)
ratio = numerator / denominator

if test_cells is not None:
if cells is not None:
# Mask all non test cells
mask = ratio.copy(deep=True)
mask[:] = False
for cell in test_cells:
for cell in cells:
mask[cell[0],cell[1]] = True
# Apply mask
ratio = ratio.where(mask)
Expand All @@ -270,45 +270,30 @@ def ssi(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
ssi = ratio.mean(dim=['origin','destination'],skipna=True)
return ssi

def von_neumann_entropy(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
    """Compute -sum(l * log(l)) over the eigenvalues l of P @ P.T + eps * I.

    Parameters
    ----------
    prediction : xr.DataArray, numpy array or torch.Tensor
        Table P. It is transposed with `.T`, so a two-dimensional
        table is expected.
    ground_truth : xr.DataArray, optional
        Unused; accepted so the signature matches the other metrics.
    **kwargs
        epsilon_threshold : float (required)
            Jitter eps added to the diagonal of P @ P.T.

    Returns
    -------
    torch.Tensor
        Zero-dimensional float32 tensor holding the entropy.

    Raises
    ------
    KeyError
        If `epsilon_threshold` is not passed.
    """
    epsilon = kwargs['epsilon_threshold']
    # All linear algebra below is done in torch, so bring the table into a
    # float32 tensor first (array-likes go through numpy).
    if isinstance(prediction, torch.Tensor):
        table = prediction.to(dtype=torch.float32)
    else:
        table = torch.as_tensor(np.asarray(prediction, dtype='float32'))
    # Convert matrix to square
    matrix = table @ table.T
    # Add jitter (torch.eye needs a torch.dtype, not a dtype string)
    matrix = matrix + epsilon * torch.eye(
        matrix.shape[0],
        dtype=torch.float32,
        device=matrix.device
    )
    # Find eigenvalues
    eigenval = torch.real(torch.linalg.eigvals(matrix))
    # Keep only non-zero eigenvalues: 0 * log(0) would evaluate to nan.
    # torch.isclose compares tensors, so the zero must be a tensor as well.
    near_zero = torch.isclose(eigenval, torch.zeros_like(eigenval), atol=1e-08)
    eigenval = eigenval[~near_zero]
    # Compute entropy
    res = torch.sum(-eigenval * torch.log(eigenval)).to(dtype=torch.float32)

    return res

def sparsity(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
    """Compute the fraction of cells in a table that are exactly zero.

    Parameters
    ----------
    prediction : xr.DataArray
        Table whose zero cells are counted. Any number of dimensions is
        accepted; anything `np.asarray` can convert works.
    ground_truth : xr.DataArray, optional
        Unused; accepted so the signature matches the other metrics.
    **kwargs
        Unused.

    Returns
    -------
    float
        Number of zero cells divided by the total number of cells
        (a fraction in [0, 1], not a percentage). `nan` for an empty table.
    """
    values = np.asarray(prediction)
    if values.size == 0:
        # No cells to count: report nan explicitly instead of dividing by zero
        return float('nan')
    res = np.count_nonzero(values == 0) / values.size
    return res

def markov_basis_distance(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
    """Compute half the summed absolute difference between two tables.

    Parameters
    ----------
    prediction : xr.DataArray
        Predicted table with `origin` and `destination` dimensions.
    ground_truth : xr.DataArray
        Reference table. Defaults to None only for signature parity with
        the other metrics; a table must be supplied.
    **kwargs
        cells : iterable of index pairs, optional
            If given, only these cells contribute to the sum. Each entry is
            used as `mask[cell[0],cell[1]]`, i.e. positional indices into
            the first two dimensions of `ground_truth`.

    Returns
    -------
    xr.DataArray
        Sum of |prediction - ground_truth| over `origin` and `destination`
        (NaNs skipped, accumulated in float64), divided by 2.
    """
    cells = kwargs.get('cells',None)
    if cells is not None:
        # Mask all cells that are not listed in `cells`
        mask = ground_truth.copy(deep=True)
        mask[:] = False
        for cell in cells:
            mask[cell[0],cell[1]] = True
        # Apply mask: unselected cells become NaN and are skipped by the sum
        prediction = prediction.where(mask)
        ground_truth = ground_truth.where(mask)

    # The reduction keyword is `dim` (as used by `ssi`); `dims` would be
    # forwarded to the underlying reducer and rejected.
    return np.abs(prediction - ground_truth).sum(
        dim = ['origin','destination'],
        dtype = 'float64',
        skipna = True
    ) / 2

def coverage_probability(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
# Get region mass
region_mass = kwargs.get('region_mass',0.95)

# Get test cells
test_cells = kwargs.get('test_cells',None)
cells = kwargs.get('cells',None)

# High posterior density mass
alpha = 1-region_mass
Expand All @@ -330,11 +315,11 @@ def coverage_probability(prediction:xr.DataArray,ground_truth:xr.DataArray=None,
# Compute flag for whether ground truth table is covered
cell_coverage = (ground_truth >= lower_bound_hpdr) & (ground_truth <= upper_bound_hpdr)

if test_cells is not None:
if cells is not None:
# Mask all non test cells
mask = cell_coverage.copy(deep=True)
mask[:] = False
for cell in test_cells:
for cell in cells:
mask[cell[0],cell[1]] = True
# Apply mask
return cell_coverage.where(mask)
Expand Down Expand Up @@ -386,6 +371,37 @@ def calculate_min_interval(x, alpha):

return hdi_min, hdi_max

def von_neumann_entropy(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
    """Compute -sum(l * log(l)) over the eigenvalues l of P @ P.T + eps * I.

    Parameters
    ----------
    prediction : xr.DataArray, numpy array or torch.Tensor
        Table P. It is transposed with `.T`, so a two-dimensional
        table is expected.
    ground_truth : xr.DataArray, optional
        Unused; accepted so the signature matches the other metrics.
    **kwargs
        epsilon_threshold : float (required)
            Jitter eps added to the diagonal of P @ P.T.

    Returns
    -------
    torch.Tensor
        Zero-dimensional float32 tensor holding the entropy.

    Raises
    ------
    KeyError
        If `epsilon_threshold` is not passed.
    """
    epsilon = kwargs['epsilon_threshold']
    # All linear algebra below is done in torch, so bring the table into a
    # float32 tensor first (array-likes go through numpy).
    if isinstance(prediction, torch.Tensor):
        table = prediction.to(dtype=torch.float32)
    else:
        table = torch.as_tensor(np.asarray(prediction, dtype='float32'))
    # Convert matrix to square
    matrix = table @ table.T
    # Add jitter (torch.eye needs a torch.dtype, not a dtype string)
    matrix = matrix + epsilon * torch.eye(
        matrix.shape[0],
        dtype=torch.float32,
        device=matrix.device
    )
    # Find eigenvalues
    eigenval = torch.real(torch.linalg.eigvals(matrix))
    # Keep only non-zero eigenvalues: 0 * log(0) would evaluate to nan.
    # torch.isclose compares tensors, so the zero must be a tensor as well.
    near_zero = torch.isclose(eigenval, torch.zeros_like(eigenval), atol=1e-08)
    eigenval = eigenval[~near_zero]
    # Compute entropy
    res = torch.sum(-eigenval * torch.log(eigenval)).to(dtype=torch.float32)

    return res

def sparsity(prediction:xr.DataArray,ground_truth:xr.DataArray=None,**kwargs):
    """Compute the fraction of cells in a table that are exactly zero.

    Parameters
    ----------
    prediction : xr.DataArray
        Table whose zero cells are counted. Any number of dimensions is
        accepted; anything `np.asarray` can convert works.
    ground_truth : xr.DataArray, optional
        Unused; accepted so the signature matches the other metrics.
    **kwargs
        Unused.

    Returns
    -------
    float
        Number of zero cells divided by the total number of cells
        (a fraction in [0, 1], not a percentage). `nan` for an empty table.
    """
    values = np.asarray(prediction)
    if values.size == 0:
        # No cells to count: report nan explicitly instead of dividing by zero
        return float('nan')
    res = np.count_nonzero(values == 0) / values.size
    return res

def logsumexp(input, dim = None):
max_val = input.max(dim = dim)
Expand Down

0 comments on commit b135d4f

Please sign in to comment.