Skip to content

Commit

Permalink
Merge pull request #16 from oh-yu/develop
Browse files Browse the repository at this point in the history
No.8 Production PR
  • Loading branch information
oh-yu authored Jun 10, 2024
2 parents 643c6d6 + 0efa437 commit 45d93b0
Show file tree
Hide file tree
Showing 14 changed files with 519 additions and 376 deletions.
97 changes: 97 additions & 0 deletions algo/algo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch

from ..utils import utils


def get_psuedo_label_weights(
    source_Y_batch: torch.Tensor, thr: float = 0.75, alpha: int = 1, device=utils.DEVICE
) -> torch.Tensor:
    """
    Compute per-sample weights from the confidence of (pseudo-)label predictions.

    Samples whose prediction is confident (beyond ``thr``) get weight 1; the
    remaining samples get ``confidence ** alpha + (1 - thr)``.

    # TODO: attach paper

    Parameters
    ----------
    source_Y_batch : torch.Tensor of shape(N, C + 1)
        All columns except the last one are treated as task predictions.
        When exactly one prediction column remains (binary case) the column
        ``utils.COL_IDX_TASK`` is read as the positive-class probability.
        Otherwise the row-wise maximum over the first ``C`` columns is used
        as the confidence.
    thr : float
        Confidence threshold. In the binary case a prediction counts as
        confident when it is ``> thr`` or ``< 1 - thr``; in the multi-class
        case when the maximum probability is ``> thr``.
    alpha : int
        Exponent applied to the confidence of non-confident samples.
    device
        Device the returned tensor is moved to.

    Returns
    -------
    psuedo_label_weights : torch.Tensor of shape(N, )
        float32 weights, detached from any autograd graph.
    """
    output_size = source_Y_batch[:, :-1].shape[1]

    if output_size == 1:
        pred_y = source_Y_batch[:, utils.COL_IDX_TASK]
        # Confidence of the predicted class: p for the positive side, 1 - p otherwise.
        confidence = torch.where(pred_y > 0.5, pred_y, 1 - pred_y)
        is_confident = (pred_y > thr) | (pred_y < 1 - thr)
    else:
        confidence = torch.max(source_Y_batch[:, :output_size], dim=1).values
        is_confident = confidence > thr

    # Vectorised replacement for the former per-element Python loop: avoids one
    # host/device synchronisation per sample while producing the same values.
    soft_weights = confidence ** alpha + (1 - thr)
    psuedo_label_weights = torch.where(is_confident, torch.ones_like(soft_weights), soft_weights)
    return psuedo_label_weights.detach().to(dtype=torch.float32, device=device)


def get_terminal_weights(
    is_target_weights: bool,
    is_class_weights: bool,
    is_psuedo_weights: bool,
    pred_source_y_domain: torch.Tensor,
    source_y_task_batch: torch.Tensor,
    psuedo_label_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Combine the optional weighting schemes into one per-sample weight.

    Each enabled scheme contributes a multiplicative factor; a disabled
    scheme contributes the neutral factor 1.

    # TODO: attach paper

    Parameters
    ----------
    is_target_weights: bool
        Use ``pred_source_y_domain / (1 - pred_source_y_domain)`` as a factor.
    is_class_weights: bool
        Use the class-balancing factor from ``_get_class_weights``.
    is_psuedo_weights: bool
        Multiply by ``psuedo_label_weights``.
    pred_source_y_domain : torch.Tensor of shape(N, )
    source_y_task_batch : torch.Tensor of shape(N, )
    psuedo_label_weights : torch.Tensor of shape(N, )

    Returns
    -------
    weights : torch.Tensor of shape(N, )
        terminal sample weights for nn.BCELoss
        (the plain int ``1`` when every flag is False).
    """
    domain_factor = pred_source_y_domain / (1 - pred_source_y_domain) if is_target_weights else 1
    class_factor = _get_class_weights(source_y_task_batch) if is_class_weights else 1

    combined = domain_factor * class_factor
    if not is_psuedo_weights:
        return combined
    return combined * psuedo_label_weights


def _get_class_weights(source_y_task_batch):
p_occupied = sum(source_y_task_batch) / source_y_task_batch.shape[0]
p_unoccupied = 1 - p_occupied
class_weights = torch.zeros_like(source_y_task_batch)
for i, y in enumerate(source_y_task_batch):
if y == 1:
class_weights[i] = p_unoccupied
elif y == 0:
class_weights[i] = p_occupied
return class_weights
190 changes: 95 additions & 95 deletions algo/coral_algo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import matplotlib.pyplot as plt
from typing import List

import torch
from torch import nn, optim
from torch import nn

from ..networks import ThreeLayersDecoder
from ..utils import utils
from .algo_utils import get_psuedo_label_weights


def get_MSE(x, y):
Expand All @@ -20,120 +21,119 @@ def get_covariance_matrix(x, y):
return cov_mat_x, cov_mat_y


def fit_coral(
source_loader, target_loader, num_epochs, task_classifier, criterion, optimizer, alpha, target_X, target_y_task
):
def fit(data, network, **kwargs):
# Args
source_loader, target_loader = data["source_loader"], data["target_loader"]
target_X, target_y_task = data["target_X"], data["target_y_task"]

feature_extractor = network["feature_extractor"]
task_classifier = network["task_classifier"]
task_optimizer = network["task_optimizer"]
feature_optimizer = network["feature_optimizer"]

config = {
"num_epochs": 1000,
"alpha": 1,
"device": utils.DEVICE,
"is_psuedo_weights": False,
"is_changing_lr": False,
"epoch_thr_for_changing_lr": 200,
"changed_lrs": [0.00005, 0.00005],
"stop_during_epochs": False,
"epoch_thr_for_stopping": 2,
}
config.update(kwargs)
num_epochs = config["num_epochs"]
alpha = config["alpha"]
device = config["device"]
is_psuedo_weights = config["is_psuedo_weights"]
is_changing_lr = config["is_changing_lr"]
epoch_thr_for_changing_lr = config["epoch_thr_for_changing_lr"]
changed_lrs = config["changed_lrs"]
stop_during_epochs = config["stop_during_epochs"]
epoch_thr_for_stopping = config["epoch_thr_for_stopping"]

# Fit
for epoch in range(1, num_epochs + 1):
task_classifier.train()
feature_extractor.train()
if stop_during_epochs & (epoch == epoch_thr_for_stopping):
break
if is_changing_lr:
feature_optimizer, task_optimizer = _change_lr_during_coral_training(
feature_optimizer, task_optimizer, epoch, epoch_thr=epoch_thr_for_changing_lr, changed_lrs=changed_lrs,
)

for (source_X_batch, source_Y_batch), (target_X_batch, _) in zip(source_loader, target_loader):
# 0. Data
source_y_task_batch = source_Y_batch[:, utils.COL_IDX_TASK] > 0.5
source_y_task_batch = source_y_task_batch.to(torch.float32)
if task_classifier.output_size == 1:
source_y_task_batch = source_Y_batch[:, utils.COL_IDX_TASK] > 0.5
source_y_task_batch = source_y_task_batch.to(torch.float32)
else:
if is_psuedo_weights:
output_size = source_Y_batch[:, :-1].shape[1]
source_y_task_batch = source_Y_batch[:, :output_size]
source_y_task_batch = torch.argmax(source_y_task_batch, dim=1)
source_y_task_batch = source_y_task_batch.to(torch.long)
else:
source_y_task_batch = source_Y_batch[:, utils.COL_IDX_TASK]
source_y_task_batch = source_y_task_batch.to(torch.long)

if is_psuedo_weights:
weights = get_psuedo_label_weights(source_Y_batch=source_Y_batch, device=device).detach()
else:
weights = torch.ones_like(source_y_task_batch)

# 1. Forward
source_X_batch = feature_extractor(source_X_batch)
target_X_batch = feature_extractor(target_X_batch)
source_out = task_classifier(source_X_batch)
target_out = task_classifier(target_X_batch)

# 1.1 Task Loss
source_preds = torch.sigmoid(source_out).reshape(-1)
loss_task = criterion(source_preds, source_y_task_batch)
if task_classifier.output_size == 1:
source_preds = torch.sigmoid(source_out).reshape(-1)
criterion_weight = nn.BCELoss(weight=weights)
loss_task = criterion_weight(source_preds, source_y_task_batch)
else:
source_preds = torch.softmax(source_out, dim=1)
criterion_weight = nn.CrossEntropyLoss(reduction="none")
loss_task = criterion_weight(source_preds, source_y_task_batch)
loss_task = loss_task * weights
loss_task = loss_task.mean()

# 1.2 CoRAL Loss
cov_mat_source, cov_mat_target = get_covariance_matrix(source_out, target_out)
k = source_out.shape[1]
loss_coral = get_MSE(cov_mat_source, cov_mat_target) * (1 / (4 * k ** 2))
loss = loss_task + loss_coral * alpha
# 2. Backward
optimizer.zero_grad()
task_optimizer.zero_grad()
feature_optimizer.zero_grad()
loss.backward()
# 3. Update Params
optimizer.step()
task_optimizer.step()
feature_optimizer.step()

# 4. Eval
with torch.no_grad():
feature_extractor.eval()
task_classifier.eval()
target_out = task_classifier(target_X)
target_out = torch.sigmoid(target_out).reshape(-1)
target_out = target_out > 0.5
target_out = task_classifier.predict(feature_extractor(target_X))
acc = sum(target_out == target_y_task) / len(target_y_task)
if epoch % 10 == 0:
print(f"Epoch: {epoch}, Loss Coral: {loss_coral}, Loss Task: {loss_task}, Acc: {acc}")
return task_classifier


if __name__ == "__main__":
# Load Data
(
source_X,
target_X,
source_y_task,
target_y_task,
x_grid,
x1_grid,
x2_grid,
) = utils.get_source_target_from_make_moons()
source_loader, target_loader, source_y_task, source_X, target_X, target_y_task = utils.get_loader(
source_X, target_X, source_y_task, target_y_task
)

# Init NN
num_classes = 1
task_classifier = ThreeLayersDecoder(input_size=2, output_size=num_classes, fc1_size=50, fc2_size=10).to(
utils.DEVICE
)
learning_rate = 0.01

criterion = nn.BCELoss()
task_optimizer = optim.Adam(task_classifier.parameters(), lr=learning_rate)

# Fit CoRAL
task_classifier = fit_coral(
source_loader,
target_loader,
num_epochs=500,
task_classifier=task_classifier,
criterion=criterion,
optimizer=task_optimizer,
alpha=1,
target_X=target_X,
target_y_task=target_y_task,
)
source_X = source_X.cpu()
target_X = target_X.cpu()
x_grid = torch.tensor(x_grid, dtype=torch.float32).to(utils.DEVICE)
y_grid = task_classifier(x_grid.T)
y_grid = torch.sigmoid(y_grid).cpu().detach().numpy()

plt.figure()
plt.title("Domain Adaptation Boundary")
plt.xlabel("X1")
plt.ylabel("X2")
plt.scatter(source_X[:, 0], source_X[:, 1], c=source_y_task)
plt.scatter(target_X[:, 0], target_X[:, 1], c="black")
plt.contourf(x1_grid, x2_grid, y_grid.reshape(100, 100), alpha=0.3)
plt.colorbar()
plt.show()

# Without DA
task_classifier = ThreeLayersDecoder(input_size=source_X.shape[1], output_size=num_classes).to(utils.DEVICE)
task_optimizer = optim.Adam(task_classifier.parameters(), lr=learning_rate)
task_classifier = utils.fit_without_adaptation(
source_loader, task_classifier, task_optimizer, criterion, num_epochs=500
)
pred_y_task = task_classifier(target_X.to(utils.DEVICE))
pred_y_task = torch.sigmoid(pred_y_task).reshape(-1)
pred_y_task = pred_y_task > 0.5
acc = sum(pred_y_task == target_y_task) / target_y_task.shape[0]
print(f"Without Adaptation Accuracy:{acc}")

y_grid = task_classifier(x_grid.T)
y_grid = torch.sigmoid(y_grid).cpu().detach().numpy()

plt.figure()
plt.title("Without Adaptation Boundary")
plt.xlabel("X1")
plt.ylabel("X2")
plt.scatter(source_X[:, 0], source_X[:, 1], c=source_y_task)
plt.scatter(target_X[:, 0], target_X[:, 1], c="black")
plt.contourf(x1_grid, x2_grid, y_grid.reshape(100, 100), alpha=0.3)
plt.colorbar()
plt.show()
return feature_extractor, task_classifier, None


def _change_lr_during_coral_training(
feature_optimizer: torch.optim.Adam,
task_optimizer: torch.optim.Adam,
epoch: torch.Tensor,
epoch_thr: int = 200,
changed_lrs: List[float] = [0.00005],
):
if epoch == epoch_thr:
feature_optimizer.param_groups[0]["lr"] = changed_lrs[0]
task_optimizer.param_groups[0]["lr"] = changed_lrs[0]
return feature_optimizer, task_optimizer
4 changes: 2 additions & 2 deletions algo/dan_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_MMD(x, y):
return mmd_xx + mmd_yy + mmd_xy


def fit_dan(
def fit(
source_loader,
target_loader,
num_epochs,
Expand Down Expand Up @@ -131,7 +131,7 @@ def fit_dan(
task_optimizer_source = optim.Adam(task_classifier_source.parameters(), lr=learning_rate)

# Fit DAN
feature_extractor, task_classifier = fit_dan(
feature_extractor, task_classifier = fit(
source_loader=source_loader,
target_loader=target_loader,
num_epochs=100,
Expand Down
Loading

0 comments on commit 45d93b0

Please sign in to comment.