Commit
different diva
smilesun committed Oct 9, 2023
1 parent c496add commit bd65d15
Showing 5 changed files with 82 additions and 15 deletions.
2 changes: 1 addition & 1 deletion domainlab/algos/builder_diva.py
@@ -35,7 +35,7 @@ def init_business(self, exp):
request = RequestVAEBuilderCHW(
task.isize.c, task.isize.h, task.isize.w, args)
node = VAEChainNodeGetter(request)()
- model = mk_diva()(node,
+ model = mk_diva(str_mu=args.str_mu)(node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
zx_dim=args.zx_dim,
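For orientation: mk_diva is now called in two steps, first with the new str_mu string to pick a DIVA variant, then with the usual constructor arguments. A minimal sketch of that call pattern, assuming a parsed args namespace and an already built VAE chain node (the remaining constructor arguments are truncated in the hunk above):

from domainlab.models.model_diva import mk_diva

# Sketch only; `args` and `node` are assumed to exist as in init_business above,
# and the constructor takes more keyword arguments than shown here.
model_class = mk_diva(str_mu=args.str_mu)  # ModelDIVADefault / ModelDIVAGammad / ModelDIVAGammadRecon
model = model_class(
    node,
    zd_dim=args.zd_dim,
    zy_dim=args.zy_dim,
    zx_dim=args.zx_dim,
    # ... remaining arguments unchanged from the original call
)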
3 changes: 3 additions & 0 deletions domainlab/algos/trainers/args_fbopt.py
@@ -36,6 +36,9 @@ def add_args2parser_fbopt(parser):
parser.add_argument('--no_setpoint_update', action='store_true', default=False,
help='disable setpoint update')

parser.add_argument('--str_mu', type=str, default="default", help='which penalty to tune')


# the following hyperparamters do not need to be tuned
parser.add_argument('--beta_mu', type=float, default=1.1,
help='how much to multiply mu each time')
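The new --str_mu flag defaults to "default" and is only validated later, inside mk_diva. A minimal sketch of exercising the flag in isolation, assuming add_args2parser_fbopt needs nothing beyond a plain ArgumentParser:

import argparse

from domainlab.algos.trainers.args_fbopt import add_args2parser_fbopt

parser = argparse.ArgumentParser()
add_args2parser_fbopt(parser)  # registers --str_mu alongside the existing fbopt flags
args = parser.parse_args(["--str_mu", "gammad_recon"])
assert args.str_mu == "gammad_recon"
args = parser.parse_args([])
assert args.str_mu == "default"  # falls back to the declared default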
5 changes: 1 addition & 4 deletions domainlab/algos/trainers/fbopt_alternate.py
@@ -98,7 +98,7 @@ def cal_delta4control(self, list1, list_setpoint):
def cal_delta_integration(self, list_old, list_new, coeff):
return [(1-coeff)*a + coeff*b for a, b in zip(list_old, list_new)]

- def search_mu(self, epo_reg_loss, epo_task_loss, dict_theta=None, miter=None):
+ def search_mu(self, epo_reg_loss, epo_task_loss, epo_loss_tr, dict_theta=None, miter=None):
"""
start from parameter dictionary dict_theta: {"layer":tensor},
enlarge mu w.r.t. its current value
@@ -137,9 +137,6 @@ def search_mu(self, epo_reg_loss, epo_task_loss, dict_theta=None, miter=None):
f'reg/setpoint{i}': reg_set,
}, miter)
self.writer.add_scalar(f'x-axis=task vs y-axis=reg/dyn{i}', reg_dyn, epo_task_loss)

- epo_loss_tr = epo_task_loss + torch.inner(
-     torch.Tensor(list(self.mmu.values())), torch.Tensor(epo_reg_loss))
self.writer.add_scalar('loss_penalized', epo_loss_tr, miter)
self.writer.add_scalar('task', epo_task_loss, miter)
acc_te = 0
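The deleted lines show what search_mu used to compute internally: the penalized loss, i.e. the task loss plus the mu-weighted inner product with the per-component regularization losses. After this commit the trainer evaluates a penalized loss in eval_r_loss (via model.cal_loss) and passes it in as epo_loss_tr instead. A small worked sketch of the removed computation, with hypothetical numbers:

import torch

# Hypothetical multipliers and per-component regularization losses, for illustration only.
mmu = {"beta_d": 1.0, "gamma_d": 0.5}
epo_reg_loss = [2.0, 3.0]
epo_task_loss = 1.0

epo_loss_tr = epo_task_loss + torch.inner(
    torch.Tensor(list(mmu.values())), torch.Tensor(epo_reg_loss))
print(epo_loss_tr)  # tensor(4.5000) = 1.0 + (1.0 * 2.0 + 0.5 * 3.0)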
12 changes: 8 additions & 4 deletions domainlab/algos/trainers/train_mu_controller.py
@@ -45,24 +45,27 @@ def eval_r_loss(self):
# mock the model hyper-parameter to be from dict4mu
epo_reg_loss = []
epo_task_loss = 0
epo_p_loss = 0
counter = 0.0
with torch.no_grad():
for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
tensor_x, vec_y, vec_d = \
tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
tuple_reg_loss = self.model.cal_reg_loss(tensor_x, vec_y, vec_d)
p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d)
# NOTE: first [0] extract the loss, second [0] get the list
list_b_reg_loss = tuple_reg_loss[0]
- list_b_reg_loss_sumed = [ele.sum().item() for ele in list_b_reg_loss]
+ list_b_reg_loss_sumed = [ele.sum().detach().item() for ele in list_b_reg_loss]
if len(epo_reg_loss) == 0:
epo_reg_loss = list_b_reg_loss_sumed
else:
epo_reg_loss = list(map(add, epo_reg_loss, list_b_reg_loss_sumed))
- b_task_loss = self.model.cal_task_loss(tensor_x, vec_y).sum()
+ b_task_loss = self.model.cal_task_loss(tensor_x, vec_y).sum().detach().item()
# sum will kill the dimension of the mini batch
epo_task_loss += b_task_loss
epo_p_loss += p_loss.sum().detach().item()
counter += 1.0
- return list_divide(epo_reg_loss, counter), epo_task_loss/counter
+ return list_divide(epo_reg_loss, counter), epo_task_loss/counter, epo_p_loss / counter

def before_batch(self, epoch, ind_batch):
"""
@@ -77,7 +80,7 @@ def before_batch(self, epoch, ind_batch):
def before_tr(self):
self.set_scheduler(scheduler=HyperSchedulerFeedbackAlternave)
self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu))
- self.epo_reg_loss_tr, self.epo_task_loss_tr = self.eval_r_loss()
+ self.epo_reg_loss_tr, self.epo_task_loss_tr, self.epo_loss_tr = self.eval_r_loss()
self.hyper_scheduler.set_setpoint(
[ele * self.aconf.ini_setpoint_ratio for ele in self.epo_reg_loss_tr],
self.epo_task_loss_tr)
@@ -90,6 +93,7 @@ def tr_epoch(self, epoch):
self.hyper_scheduler.search_mu(
self.epo_reg_loss_tr,
self.epo_task_loss_tr,
self.epo_loss_tr,
dict(self.model.named_parameters()),
miter=epoch)
self.hyper_scheduler.update_setpoint(self.epo_reg_loss_tr, self.epo_task_loss_tr)
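eval_r_loss now returns a third value, the batch-averaged penalized loss taken from model.cal_loss, and before_tr stores it as self.epo_loss_tr so that tr_epoch can forward it to search_mu. A sketch of the updated calling contract, assuming a trainer instance and an epoch counter:

# Sketch of the new data flow; `trainer` and `epoch` are assumed to exist.
epo_reg_loss_tr, epo_task_loss_tr, epo_loss_tr = trainer.eval_r_loss()
trainer.hyper_scheduler.search_mu(
    epo_reg_loss_tr,   # list of per-component regularization losses, batch-averaged
    epo_task_loss_tr,  # scalar task loss, batch-averaged
    epo_loss_tr,       # scalar penalized loss, batch-averaged (the new argument)
    dict(trainer.model.named_parameters()),
    miter=epoch,
)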
75 changes: 69 additions & 6 deletions domainlab/models/model_diva.py
@@ -9,7 +9,7 @@
from domainlab.utils.utils_class import store_args


- def mk_diva(parent_class=VAEXYDClassif):
+ def mk_diva(parent_class=VAEXYDClassif, str_mu="default"):
"""
Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss.
Expand Down Expand Up @@ -89,8 +89,6 @@ def hyper_update(self, epoch, fun_scheduler):
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]
self.mu_recon = dict_rst["mu_recon"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
@@ -100,11 +98,9 @@ def hyper_init(self, functor_scheduler, trainer=None):
"""
return functor_scheduler(
trainer=trainer,
- mu_recon=self.mu_recon,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
- gamma_d=self.gamma_d,
)

def get_list_str_y(self):
@@ -142,4 +138,71 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
lc_d = F.cross_entropy(logit_d, d_target, reduction="none")
return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \
[self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
- return ModelDIVA

class ModelDIVAGammadRecon(ModelDIVA):
def hyper_update(self, epoch, fun_scheduler):
"""hyper_update.
:param epoch:
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]
self.mu_recon = dict_rst["mu_recon"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
return functor_scheduler(
trainer=trainer,
mu_recon=self.mu_recon,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
gamma_d=self.gamma_d,
)


class ModelDIVAGammad(ModelDIVA):
def hyper_update(self, epoch, fun_scheduler):
"""hyper_update.
:param epoch:
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
return functor_scheduler(
trainer=trainer,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
gamma_d=self.gamma_d,
)

class ModelDIVADefault(ModelDIVA):
"""
"""
if str_mu == "gammad_recon":
return ModelDIVAGammadRecon
if str_mu == "gammad":
return ModelDIVAGammad
if str_mu == "default":
return ModelDIVADefault
raise RuntimeError("not support argument candiates for str_mu: allowed: default, gammad_recon, gammad")
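Taken together, mk_diva now dispatches on str_mu, and the returned subclass determines which multipliers the feedback scheduler updates: after this commit the base class only exposes beta_d, beta_y and beta_x; ModelDIVAGammad adds gamma_d; ModelDIVAGammadRecon adds gamma_d and mu_recon. A minimal sketch of the dispatch:

from domainlab.models.model_diva import mk_diva

cls_default = mk_diva(str_mu="default")       # ModelDIVADefault: scheduler sees beta_d, beta_y, beta_x
cls_gammad = mk_diva(str_mu="gammad")         # ModelDIVAGammad: additionally exposes gamma_d
cls_recon = mk_diva(str_mu="gammad_recon")    # ModelDIVAGammadRecon: additionally exposes gamma_d and mu_recon
mk_diva(str_mu="typo")                        # raises RuntimeError listing the allowed values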
