Skip to content

Commit

Permalink
made metrics calculation optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Lebedev committed Nov 15, 2023
1 parent 80a7e09 commit d53c9bf
Showing 1 changed file with 92 additions and 44 deletions.
136 changes: 92 additions & 44 deletions alphadia/fdrexperimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def __init__(
weight_decay: float = 0.00001,
layers: typing.List[int] = [100, 50, 20, 5],
dropout: float = 0.001,
metric_interval: int = 1000,
num_workers: int = 0,
calculate_metrics: bool = False,
metric_interval: int = 1,
patience: int = 3,
):
"""Binary Classifier using a feed forward neural network.
Expand Down Expand Up @@ -165,8 +166,15 @@ def __init__(
dropout : float, default=0.001
Dropout probability for training.
metric_interval : int, default=1000
Interval for logging metrics during training.
calculate_metrics : bool, default=False
Whether to calculate metrics during training.
metric_interval : int, default=1
Interval for logging metrics during training, once per metric_interval epochs.
patience : int, default=3
Number of epochs to wait for improvement before early stopping.
"""

self.test_size = test_size
Expand All @@ -179,6 +187,8 @@ def __init__(
self.input_dim = input_dim
self.output_dim = output_dim
self.metric_interval = metric_interval
self.calculate_metrics = calculate_metrics
self.patience = patience

self.network = None
self.optimizer = None
Expand Down Expand Up @@ -259,6 +269,32 @@ def from_state_dict(self, state_dict: dict):

self.__dict__.update(_state_dict)

def _prepare_data(self, x: np.ndarray, y: np.ndarray):
"""Prepare the data for training: normalize, split into train and test set.
Parameters
----------
x : np.array, dtype=float
Training data of shape (n_samples, n_features).
y : np.array, dtype=int
Target values of shape (n_samples,) or (n_samples, n_classes).
"""
x -= x.mean(axis=0)
x /= x.std(axis=0) + 1e-6

if y.ndim == 1:
y = np.stack([1 - y, y], axis=1)
x_train, x_test, y_train, y_test = model_selection.train_test_split(
x, y, test_size=self.test_size
)
x_train = torch.from_numpy(x_train).float()
y_train = torch.from_numpy(y_train).float()
x_test = torch.from_numpy(x_test).float()
y_test = torch.from_numpy(y_test).float()
return x_train, x_test, y_train, y_test

def fit(self, x: np.ndarray, y: np.ndarray):
"""Fit the classifier to the data.
Expand Down Expand Up @@ -291,38 +327,25 @@ def fit(self, x: np.ndarray, y: np.ndarray):
dropout=self.dropout,
)

best_test_accuracy = 0.0
patience = 5
counter_patience = 0

# normalize input
x -= x.mean(axis=0)
x /= x.std(axis=0) + 1e-6

if y.ndim == 1:
y = np.stack([1 - y, y], axis=1)
x_train, x_test, y_train, y_test = model_selection.train_test_split(
x, y, test_size=self.test_size
)

x_train = torch.from_numpy(x_train).float()
y_train = torch.from_numpy(y_train).float()
x_test = torch.from_numpy(x_test).float()
y_test = torch.from_numpy(y_test).float()


optimizer = optim.AdamW(
self.network.parameters(),
lr = self.learning_rate,
weight_decay = self.weight_decay,
lr=self.learning_rate,
weight_decay=self.weight_decay,
)

loss = nn.BCELoss()

batch_count = 0
best_train_accuracy = 0.0
best_test_accuracy = 0.0
patience = self.patience
x_train, x_test, y_train, y_test = self._prepare_data(x, y)

for j in tqdm(range(self.epochs)):
num_batches = (x_train.shape[0] // self.batch_size) - 1
batch_start_list = np.arange(num_batches) * self.batch_size
batch_stop_list = np.arange(num_batches) * self.batch_size + self.batch_size

batch_count = 0
for epoch in tqdm(range(self.epochs)):
train_loss_sum = 0.0
train_accuracy_sum = 0.0
test_loss_sum = 0.0
Expand All @@ -331,42 +354,67 @@ def fit(self, x: np.ndarray, y: np.ndarray):
num_batches_train = 0
num_batches_test = 0

num_batches = (x_train.shape[0] // self.batch_size) -1
batch_start_list = np.arange(num_batches) * self.batch_size
batch_stop_list = np.arange(num_batches) * self.batch_size + self.batch_size

# shuffle batches
order = np.random.permutation(num_batches)
batch_start_list = batch_start_list[order]
batch_stop_list = batch_stop_list[order]

for (batch_start, batch_stop) in zip(batch_start_list, batch_stop_list):
for batch_start, batch_stop in zip(batch_start_list, batch_stop_list):
y_pred = self.network(x_train[batch_start:batch_stop])
loss_value = loss(y_pred, y_train[batch_start:batch_stop])

optimizer.zero_grad()
loss_value.backward()
optimizer.step()

train_loss_sum += loss_value.detach()
train_accuracy_sum += (
(y_train[batch_start:batch_stop][:, 1] == y_pred.argmax(dim=1)).float().mean()
(y_train[batch_start:batch_stop][:, 1] == y_pred.argmax(dim=1))
.float()
.mean()
)
num_batches_train += 1

if not self.calculate_metrics:
# check for early stopping
average_train_accuracy = train_accuracy_sum / num_batches_train
if average_train_accuracy > best_train_accuracy:
best_train_accuracy = average_train_accuracy
patience = self.patience
else:
patience -= 1

if patience <= 0:
break
continue

if epoch % self.metric_interval != 0: # skip metrics if wrong epoch
continue

self.network.eval()
with torch.no_grad():
test_num_batches = (x_test.shape[0] // self.batch_size) -1
test_num_batches = (x_test.shape[0] // self.batch_size) - 1
test_batch_start_list = np.arange(test_num_batches) * self.batch_size
test_batch_stop_list = np.arange(test_num_batches) * self.batch_size + self.batch_size
test_batch_stop_list = (
np.arange(test_num_batches) * self.batch_size + self.batch_size
)

for (batch_start, batch_stop) in zip(test_batch_start_list, test_batch_stop_list):
for batch_start, batch_stop in zip(
test_batch_start_list, test_batch_stop_list
):
batch_x_test = x_test[batch_start:batch_stop]
batch_y_test = y_test[batch_start:batch_stop]

y_pred_test = self.network(batch_x_test)
test_loss = loss(y_pred_test, batch_y_test)
test_accuracy = (y_test[batch_start:batch_stop][:, 1] == y_pred_test.argmax(dim=1)).float().mean()
test_accuracy = (
(
y_test[batch_start:batch_stop][:, 1]
== y_pred_test.argmax(dim=1)
)
.float()
.mean()
)
num_batches_test += 1
test_accuracy_sum += test_accuracy
test_loss_sum += test_loss
Expand All @@ -385,19 +433,19 @@ def fit(self, x: np.ndarray, y: np.ndarray):

self.metrics["test_loss"].append(average_test_loss.item())
self.metrics["test_accuracy"].append(average_test_accuracy.item())
self.metrics["epoch"].append(j)
self.metrics["epoch"].append(epoch)

batch_count += num_batches_train
self.metrics["batch_count"].append(batch_count)

#check for early stopping
# check for early stopping
if average_test_accuracy > best_test_accuracy:
best_test_accuracy = average_test_accuracy
counter_patience = 0
patience = self.patience
else:
counter_patience += 1
patience -= 1

if counter_patience >= patience:
if patience <= 0:
break

self._fitted = True
Expand Down

0 comments on commit d53c9bf

Please sign in to comment.