Skip to content

Commit

Permalink
add dacon metric
Browse files Browse the repository at this point in the history
  • Loading branch information
Takshan authored Aug 29, 2024
1 parent d342dd2 commit 2965802
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions rahulscripts/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import * from custom_metrics
45 changes: 45 additions & 0 deletions rahulscripts/metrics/custom_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
from torchmetrics import Metric


class DaconScore(Metric):
def __init__(self, dist_sync_on_step=False, **kwargs):
super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs)
self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
self.unit = kwargs.get("unit", 9)

def update(self, pred: torch.Tensor, target: torch.Tensor):
score = self.score_compute(pred, target)
self.score += score
self.count += 1

def compute(self) -> torch.Tensor:
return self.score / self.count

def rmse_compute(self, pred, target):
return torch.sqrt(((pred - target) ** 2).mean())

def normalized_rmse_compute(self, pred, target):
return self.rmse_compute(pred, target) / (target.max() - target.min())

def correct_ratio(self, pred, target):
pIC50_pred = self.ic50_to_pic50(pred, self.unit)
pIC50_target = self.ic50_to_pic50(target, self.unit)
diff = pIC50_pred - pIC50_target
correct = torch.sum(diff <= 0.5)
return correct / len(pred)

def score_compute(self, pred, target):
if not isinstance(pred, torch.Tensor):
pred = torch.tensor(pred)
if not isinstance(target, torch.Tensor):
target = torch.tensor(target)
score = 0.5 * (1 - min(self.normalized_rmse_compute(pred, target), 1)) + (
0.5 * self.correct_ratio(pred, target)
)
return score

def ic50_to_pic50(self, ic50_value: float, unit: int = 9) -> float:
pic50 = unit - torch.log10(ic50_value)
return pic50

0 comments on commit 2965802

Please sign in to comment.