From b2b204d831efa48226d13a62f7f6fe7d552bdeed Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 4 Oct 2024 16:48:45 +0200
Subject: [PATCH 01/40] copy branch fbopt_vector_ki_gain to mhof_dev

---
 domainlab/algos/trainers/fbopt_mu_controller.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 824638461..f61b4bc63 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -50,7 +50,8 @@ def __init__(self, trainer, **kwargs):
         self.mmu = {key: self.init_mu for key, val in self.mmu.items()}
         self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf)
-        self.k_i_control = trainer.aconf.k_i_gain
+        self.k_i_control = [trainer.aconf.k_i_gain for i in
+                            range(len(self.mmu))]
         self.k_i_gain_ratio = None
         self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes"
         self.delta_epsilon_r = None
@@ -84,7 +85,7 @@ def set_k_i_gain(self, epo_reg_loss):
         k_i_gain_saturate_min = min(k_i_gain_saturate)
         # NOTE: here we override the commandline arguments specification
         # for k_i_control, so k_i_control is not a hyperparameter anymore
-        self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min
+        self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate]
         warnings.warn(
             f"hyperparameter k_i_gain disabled! \
             replace with {self.k_i_control}"
         )
@@ -162,7 +163,7 @@ def cal_activation(self):
         """
         setpoint = self.get_setpoint4r()
         activation = [
-            self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val)
+            self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val)
             for i, val in enumerate(self.delta_epsilon_r)
         ]
         if self.activation_clip is not None:

From 6626475c39bad4b2851e13674a0aa24bbae2bc5a Mon Sep 17 00:00:00 2001
From: smilesun
Date: Tue, 8 Oct 2024 16:50:56 +0200
Subject: [PATCH 02/40] use self.decoratee instead of self.model

---
 domainlab/algos/trainers/a_trainer.py           |  1 +
 domainlab/algos/trainers/fbopt_mu_controller.py | 12 +++++++++---
 domainlab/algos/trainers/train_fbopt_b.py       |  3 ++-
 tests/test_fbopt_irm.py                         | 14 ++++++++++++++
 4 files changed, 26 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_fbopt_irm.py

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 871abd8f0..890d5aea0 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -53,6 +53,7 @@ def __init__(self, successor_node=None, extend=None):
         """
         super().__init__(successor_node)
         self._model = None
+        # decoratee can be both model or trainer
         self._decoratee = extend
         self.task = None
         self.observer = None

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index f61b4bc63..5c635e147 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -52,7 +52,7 @@ def __init__(self, trainer, **kwargs):
         self.k_i_control = [trainer.aconf.k_i_gain for i in
                             range(len(self.mmu))]
-        self.k_i_gain_ratio = None
+        self.k_i_gain_ratio = trainer.aconf.k_i_gain_ratio
         self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes"
         self.delta_epsilon_r = None
@@ -71,7 +71,10 @@ def __init__(self, trainer, **kwargs):

     def set_k_i_gain(self, epo_reg_loss):
         if self.k_i_gain_ratio is None:
-            return
+            if self.k_i_control:
+                return
+            raise RuntimeError("set either direct k_i_control value or \
+                set k_i_gain_ratio, can not be both empty!")
         # NOTE: do not use self.cal_delta4control!!!! which will change
         # class member variables self.delta_epsilon_r!
         list_setpoint = self.get_setpoint4r()
@@ -79,9 +82,12 @@ def set_k_i_gain(self, epo_reg_loss):
         delta_epsilon_r = [a - b for a, b in zip(epo_reg_loss, list_setpoint)]
         # to calculate self.delta_epsilon_r
+        list_active = [self.activation_clip for i in range(len(delta_epsilon_r))]
+
         k_i_gain_saturate = [
-            a / b for a, b in zip(self.activation_clip, delta_epsilon_r)
+            a / b for a, b in zip(list_active, delta_epsilon_r)
         ]
+
         k_i_gain_saturate_min = min(k_i_gain_saturate)
         # NOTE: here we override the commandline arguments specification
         # for k_i_control, so k_i_control is not a hyperparameter anymore

diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index 1efe3ce58..e922a7521 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -66,7 +66,7 @@ def eval_r_loss(self):
                 vec_y.to(self.device),
                 vec_d.to(self.device),
             )
-            tuple_reg_loss = self.model.cal_reg_loss(tensor_x, vec_y, vec_d, others)
+            tuple_reg_loss = self.decoratee.cal_reg_loss(tensor_x, vec_y, vec_d, others)
             p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
             # NOTE: first [0] extract the loss, second [0] get the list
             list_b_reg_loss = tuple_reg_loss[0]
@@ -127,6 +127,7 @@ def before_tr(self):
             ],
             self.epo_task_loss_tr,
         )  # setpoing w.r.t. random initialization of neural network
+        # FIXME: check if self.epo_reg_loss_tr is zero!!
         self.hyper_scheduler.set_k_i_gain(self.epo_reg_loss_tr)

     @property

diff --git a/tests/test_fbopt_irm.py b/tests/test_fbopt_irm.py
new file mode 100644
index 000000000..fb1f109f6
--- /dev/null
+++ b/tests/test_fbopt_irm.py
@@ -0,0 +1,14 @@
+"""
+  end-end test
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_mhof_irm():
+    """
+    mhof-irm
+    """
+    args = "--te_d=0 --task=mnistcolor10 --model=erm \
+        --trainer=fbopt_irm --nname=conv_bn_pool_2 \
+        --k_i_gain_ratio=0.5"
+    utils_test_algo(args)

From d5d3a0b982d9d32b6a22a58c425af1e7a7f06e8f Mon Sep 17 00:00:00 2001
From: smilesun
Date: Tue, 8 Oct 2024 16:52:39 +0200
Subject: [PATCH 03/40] cmd script to test mhof irm

---
 a_test_mhof_irm.sh | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 a_test_mhof_irm.sh

diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh
new file mode 100644
index 000000000..610d3a606
--- /dev/null
+++ b/a_test_mhof_irm.sh
@@ -0,0 +1 @@
+python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5
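The vectorized gain introduced in PATCH 01 is easiest to see outside the diff context. Below is a minimal sketch — not DomainLab code, all names illustrative — of the per-component multiplicative integral update that a per-term `k_i_control[i]` enables: each regularization loss gets its own control error and its own update rate.

```python
# Illustrative sketch of per-component integral control (assumed semantics,
# based on the controller's help texts; not the actual fbopt_mu_controller).
import math

def update_multipliers(mmu, reg_losses, setpoints, k_i_control,
                       shoulder_clip=5.0, mu_min=1e-6, mu_clip=1e4):
    """One controller step: grow mu_i while reg loss i sits above setpoint i."""
    new_mmu = []
    for mu, reg, b, k_i in zip(mmu, reg_losses, setpoints, k_i_control):
        delta = reg - b                                   # control error
        act = max(-shoulder_clip, min(shoulder_clip, k_i * delta))
        mu = mu * math.exp(act)                           # multiplicative action
        new_mmu.append(min(max(mu, mu_min), mu_clip))     # keep mu in bounds
    return new_mmu
```

With a scalar gain, one badly scaled loss term dictates the update speed of every multiplier; the list-valued gain lets each term saturate independently.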
From 6701cbf3a8108a309bc655de9bb31a85b55dec79 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 11:36:17 +0200
Subject: [PATCH 04/40] enable grad for irm inside torch.no_grad for mhof

---
 domainlab/algos/trainers/train_irm.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index 45797bf00..a2753989e 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -60,12 +60,13 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         _ = tensor_d
         _ = others
         y = tensor_y
-        phi = self._cal_phi(tensor_x)
-        dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_()
-        loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2])
-        loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
-        grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
-        grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
-        loss_irm_scalar = torch.sum(grad_1 * grad_2)  # scalar
-        loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
-        return [loss_irm_tensor], [self.aconf.gamma_reg]
+        with torch.enable_grad():
+            phi = self._cal_phi(tensor_x)
+            dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_()
+            loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2])
+            loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
+            grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
+            grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
+            loss_irm_scalar = torch.sum(grad_1 * grad_2)  # scalar
+            loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
+            return [loss_irm_tensor], [self.aconf.gamma_reg]

From 888d7143e92dc190ca123be1f79d54fba98cdfcd Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 12:24:25 +0200
Subject: [PATCH 05/40] filter out zero reg loss in abstract trainer

---
 domainlab/algos/trainers/a_trainer.py | 4 ++++
 domainlab/algos/trainers/train_irm.py | 2 +-
 domainlab/models/a_model_classif.py   | 1 +
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 890d5aea0..35a5b0fe2 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -298,6 +298,10 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         list_loss_tensor = list_reg_loss_model_tensor + \
             list_reg_loss_trainer_tensor
         list_mu = list_mu_model + list_mu_trainer
+        # ERM return a tensor of all zeros, delete here
+        list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() for i in range(len(list_mu))]
+        list_loss_tensor = [list_loss_tensor[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
+        list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
         return list_loss_tensor, list_mu

     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):

diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index 45797bf00..a2753989e 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -37,7 +37,7 @@ def tr_epoch(self, epoch, flag_info=False):
                 list_domain_loss_erm.append(
                     self.model.cal_task_loss(tensor_x, tensor_y))
                 list_1ele_loss_irm, _ = \
-                    self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+                    self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
                 list_domain_reg += list_1ele_loss_irm
             loss = torch.sum(torch.stack(list_domain_loss_erm)) + \
                 self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg))

diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py
index 1f72eec0a..ea9ef5f47 100644
--- a/domainlab/models/a_model_classif.py
+++ b/domainlab/models/a_model_classif.py
@@ -243,3 +243,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         device = tensor_x.device
         bsize = tensor_x.shape[0]
         return [torch.zeros(bsize).to(device)], [0.0]
+        # return [], []
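PATCH 04 relies on a specific PyTorch behavior: `autograd.grad` needs a recorded graph, and `torch.enable_grad()` re-enables graph building even when the caller (here, the feedback optimizer's setpoint evaluation) runs under `torch.no_grad()`. A small self-contained demo of just that mechanism:

```python
# Demo of torch.enable_grad() inside torch.no_grad() (standard PyTorch,
# independent of DomainLab): the inner context records a graph again.
import torch

x = torch.tensor(2.0, requires_grad=True)
with torch.no_grad():
    with torch.enable_grad():
        y = x * x                               # graph recorded despite outer no_grad
        g = torch.autograd.grad(y, [x])[0]
print(g)  # tensor(4.)
```

Without the inner context, `autograd.grad` raises an error because `y` would have no `grad_fn`, which is exactly the failure mode the patch fixes for the IRM penalty under mhof.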
From 23607fb430e166fdada5e8b931e66be2117ab63a Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 14:45:26 +0200
Subject: [PATCH 06/40] trainer behaves like model, now decoratee's cal_loss
 has to be changed

---
 domainlab/algos/trainers/a_trainer.py     | 34 +++++++++++++++++++++--
 domainlab/algos/trainers/train_basic.py   |  5 ++--
 domainlab/algos/trainers/train_dial.py    |  8 ------
 domainlab/algos/trainers/train_fbopt_b.py | 21 ++++++++------
 4 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 35a5b0fe2..35e3f5f21 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -97,6 +97,8 @@ def __init__(self, successor_node=None, extend=None):
         self._ma_iter = 0
         #
         self.list_reg_over_task_ratio = None
+        # mhof
+        self.dict_multiplier = {}

     @property
@@ -200,11 +202,18 @@ def before_tr(self):
         """
         before training, probe model performance
         """
-        self.cal_reg_loss_over_task_loss_ratio()
+        list_mu = self.cal_reg_loss_over_task_loss_ratio()
+        self.dict_multiplier = {'mu4regloss'+str(i): value for i, value in enumerate(list_mu)}
+
+    @property
+    def list_str_multiplier_na(self):
+        list_str = list(self.dict_multiplier.keys())
+        return list_str

     def cal_reg_loss_over_task_loss_ratio(self):
         list_accum_reg_loss = []
         loss_task_agg = 0
+        list_mu = None
         for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
             self.loader_tr
         ):
@@ -215,7 +224,7 @@ def cal_reg_loss_over_task_loss_ratio(self):
                 tensor_y.to(self.device),
                 tensor_d.to(self.device),
             )
-            list_reg_loss_tensor, _ = \
+            list_reg_loss_tensor, list_mu = \
                 self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
             list_reg_loss_tensor = [torch.sum(tensor).detach().item()
                                     for tensor in list_reg_loss_tensor]
@@ -232,6 +241,7 @@ def cal_reg_loss_over_task_loss_ratio(self):
             loss_task_agg += tensor_loss_task
         self.list_reg_over_task_ratio = [reg_loss / loss_task_agg
                                          for reg_loss in list_accum_reg_loss]
+        return list_mu

     def post_tr(self):
         """
@@ -327,3 +337,23 @@ def print_parameters(self):
         """
         params = vars(self)
         print(f"Parameters of {type(self).__name__}: {params}")
+
+    def hyper_init(self, functor_scheduler, trainer):
+        """
+        initialize both trainer's multiplier and model's multiplier
+        """
+        if not self.dict_multiplier:
+            raise RuntimeError("self.dict_multiplier empty!")
+        return functor_scheduler(
+            trainer=trainer, **self.dict_multiplier
+        )
+
+    def hyper_update(self, epoch, fun_scheduler):
+        """hyper_update.
+
+        :param epoch:
+        :param fun_scheduler:
+        """
+        dict_rst = fun_scheduler(epoch)
+        for key in self.dict_multiplier:
+            self.dict_multiplier[key] = dict_rst[key]

diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index 100c274b4..d06952005 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -23,6 +23,7 @@ def before_tr(self):
         """
         self.model.evaluate(self.loader_te, self.device)
         super().before_tr()
+        self.before_epoch()

     def before_epoch(self):
         """
@@ -94,7 +95,6 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
         list_reg_tr_batch, list_mu_tr = self.cal_reg_loss(
             tensor_x, tensor_y, tensor_d, others
         )
-        list_mu_tr_normalized = list_mu_tr
         if self.list_reg_over_task_ratio:
             assert len(list_mu_tr) == len(self.list_reg_over_task_ratio)
@@ -105,6 +105,7 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
         tensor_batch_reg_loss_penalized = self.model.list_inner_product(
             list_reg_tr_batch, list_mu_tr_normalized
         )
+        assert len(tensor_batch_reg_loss_penalized.shape) == 1
         loss_erm_agg = g_tensor_batch_agg(loss_task)
         loss_reg_penalized_agg = g_tensor_batch_agg(tensor_batch_reg_loss_penalized)
@@ -112,4 +113,4 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
             self.model.multiplier4task_loss * loss_erm_agg + loss_reg_penalized_agg
         )
         self.log_loss(list_reg_tr_batch, loss_task, loss_penalized)
-        return loss_penalized
+        return loss_penalized, list_reg_tr_batch, loss_erm_agg

diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py
index dbcb50eae..4fe700f45 100644
--- a/domainlab/algos/trainers/train_dial.py
+++ b/domainlab/algos/trainers/train_dial.py
@@ -51,11 +51,3 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False)
         loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
         return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
-
-    def hyper_init(self, functor_scheduler, trainer):
-        """
-        initialize both trainer's multiplier and model's multiplier
-        """
-        fun_scheduler = super().hyper_init(functor_scheduler, trainer)
-        return fun_scheduler
-        # FIXME: register also the trainer hyperpars: return functor_scheduler(trainer=trainer, gamma_reg=self.gamma_reg)

diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index e922a7521..43a5bf619 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -43,7 +43,7 @@ def set_scheduler(self, scheduler):
         this class name will be created inside model
         """
         # model.hyper_init will register the hyper-parameters of the model to scheduler
-        self.hyper_scheduler = self.model.hyper_init(scheduler, trainer=self)
+        self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self)

     def eval_r_loss(self):
         """
@@ -67,7 +67,9 @@ def eval_r_loss(self):
             tuple_reg_loss = self.decoratee.cal_reg_loss(tensor_x, vec_y, vec_d, others)
-            p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
+            p_loss, *_ = self.decoratee.cal_loss(tensor_x, vec_y, vec_d, others)
+            if p_loss.dim() > 0:
+                p_loss = p_loss.sum()
             # NOTE: first [0] extract the loss, second [0] get the list
             list_b_reg_loss = tuple_reg_loss[0]
             list_b_reg_loss_sumed = [
@@ -82,7 +84,7 @@ def eval_r_loss(self):
             )  # sum will kill the dimension of the mini batch
             epo_task_loss += b_task_loss
-            epo_p_loss += p_loss.sum().detach().item()
+            epo_p_loss += p_loss.detach().item()
             counter += 1.0
         return (
             list_divide(epo_reg_loss, counter),
@@ -103,6 +105,8 @@ def before_batch(self, epoch, ind_batch):
         return super().after_batch(epoch, ind_batch)

     def before_tr(self):
+        if hasattr(self.decoratee, "before_tr"):
+            self.decoratee.before_tr()
         self.flag_setpoint_updated = False
         if self.aconf.force_feedforward:
             self.set_scheduler(scheduler=HyperSchedulerWarmupLinear)
@@ -126,8 +130,7 @@ def before_tr(self):
             ],
             self.epo_task_loss_tr,
-        )  # setpoing w.r.t. random initialization of neural network
-        # FIXME: check if self.epo_reg_loss_tr is zero!!
+        )  # setpoint w.r.t. random initialization of neural network
         self.hyper_scheduler.set_k_i_gain(self.epo_reg_loss_tr)

     @property
@@ -135,7 +138,7 @@ def list_str_multiplier_na(self):
         """
         return the name of multipliers
         """
-        return self.model.list_str_multiplier_na
+        return self.decoratee.list_str_multiplier_na

     def tr_with_init_mu(self):
         """
@@ -147,7 +150,7 @@ def set_model_with_mu(self):
         """
         set model multipliers
         """
-        self.model.hyper_update(
+        self.decoratee.hyper_update(
             epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)
         )
@@ -163,9 +166,9 @@ def tr_epoch(self, epoch, flag_info=False):
             miter=epoch,
         )
         self.set_model_with_mu()
-        if hasattr(self.model, "dict_multiplier"):
+        if hasattr(self.decoratee, "dict_multiplier"):
             logger = Logger.get_logger()
-            logger.info(f"current multiplier: {self.model.dict_multiplier}")
+            logger.info(f"current multiplier: {self.decoratee.dict_multiplier}")

         if self._decoratee is not None:
             flag = self._decoratee.tr_epoch(epoch, self.flag_setpoint_updated)
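After PATCH 06, `cal_loss` returns a three-element tuple instead of a bare scalar, which is why downstream callers unpack with `loss, *_ = ...`. A minimal sketch — simplified from `train_basic.py`, with the ratio-normalization branch omitted since its body is outside the shown hunks — of how the penalized loss is assembled:

```python
# Illustrative sketch (not the exact DomainLab implementation) of the
# three-term return value cal_loss acquires in PATCH 06.
import torch

def cal_loss_sketch(loss_task, list_reg, list_mu, multiplier4task_loss=1.0):
    # per-sample penalty: inner product of per-term reg-loss tensors and multipliers
    reg_penalized = sum(mu * reg for mu, reg in zip(list_mu, list_reg))
    loss_erm_agg = loss_task.sum()
    loss_penalized = multiplier4task_loss * loss_erm_agg + reg_penalized.sum()
    return loss_penalized, list_reg, loss_erm_agg

loss, *_ = cal_loss_sketch(torch.rand(4), [torch.rand(4)], [0.1])
```

Returning the per-term regularization tensors alongside the scalar is what lets the wrapping fbopt trainer observe each regularizer separately instead of only the already-collapsed penalty.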
From 22f343cac977767ac1ef9165eb22fe284f71ae24 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 15:29:46 +0200
Subject: [PATCH 07/40] overwrite multiplier from scheduler(default static
 scheduler then no change) in cal_reg_loss

---
 domainlab/algos/trainers/a_trainer.py   |  2 ++
 domainlab/algos/trainers/train_basic.py |  2 +-
 domainlab/algos/trainers/train_irm.py   | 30 -------------------------
 domainlab/algos/trainers/zoo_trainer.py |  2 ++
 4 files changed, 5 insertions(+), 31 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 35e3f5f21..d53bac29f 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -312,6 +312,8 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() for i in range(len(list_mu))]
         list_loss_tensor = [list_loss_tensor[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
         list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
+        if self.dict_multiplier:
+            list_mu = list(self.dict_multiplier.values())
         return list_loss_tensor, list_mu

     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):

diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index d06952005..0e7faf8e4 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -80,7 +80,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
             tensor_d.to(self.device),
         )
         self.optimizer.zero_grad()
-        loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
+        loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss.backward()
         self.optimizer.step()
         self.after_batch(epoch, ind_batch)

diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index 377759a2f..aaa9f2dd0 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -19,36 +19,6 @@ class TrainerIRM(TrainerBasic):
     For more details, see section 3.2 and Appendix D of :
     Arjovsky et al., “Invariant Risk Minimization.”
     """
-    def tr_epoch(self, epoch, flag_info=False):
-        list_loaders = list(self.dict_loader_tr.values())
-        loaders_zip = zip(*list_loaders)
-        self.model.train()
-        self.epo_loss_tr = 0
-
-        for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip):
-            self.optimizer.zero_grad()
-            list_domain_loss_erm = []
-            list_domain_reg = []
-            for batch_domain_e in tuple_data_domains_batch:
-                tensor_x, tensor_y, tensor_d, *others = batch_domain_e
-                tensor_x, tensor_y, tensor_d = \
-                    tensor_x.to(self.device), tensor_y.to(self.device), \
-                    tensor_d.to(self.device)
-                list_domain_loss_erm.append(
-                    self.model.cal_task_loss(tensor_x, tensor_y))
-                list_1ele_loss_irm, _ = \
-                    self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
-                list_domain_reg += list_1ele_loss_irm
-            loss = torch.sum(torch.stack(list_domain_loss_erm)) + \
-                self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg))
-            loss.backward()
-            self.optimizer.step()
-            self.epo_loss_tr += loss.detach().item()
-            self.after_batch(epoch, ind_batch)
-
-        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
-        return flag_stop
-
     def _cal_phi(self, tensor_x):
         logits = self.model.cal_logit_y(tensor_x)
         return logits

diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py
index 93b40acfe..dcf1785dc 100644
--- a/domainlab/algos/trainers/zoo_trainer.py
+++ b/domainlab/algos/trainers/zoo_trainer.py
@@ -11,6 +11,7 @@
 from domainlab.algos.trainers.train_mldg import TrainerMLDG
 from domainlab.algos.trainers.train_fishr import TrainerFishr
 from domainlab.algos.trainers.train_irm import TrainerIRM
+from domainlab.algos.trainers.train_irm_sep_dom import TrainerIRMSepDom
 from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
 from domainlab.algos.trainers.train_coral import TrainerCoral
@@ -57,6 +58,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
             chain = TrainerMLDG(chain)
             chain = TrainerFishr(chain)
             chain = TrainerIRM(chain)
+            chain = TrainerIRMSepDom(chain)
             chain = TrainerHyperScheduler(chain)
             chain = TrainerFbOpt(chain)
             chain = TrainerCausalIRL(chain)

From 23607fb430e166fdada5e8b931e66be2117ab63b Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 15:30:13 +0200
Subject: [PATCH 08/40] per domain irm to separate file

---
 domainlab/algos/trainers/train_irm_sep_dom.py | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 domainlab/algos/trainers/train_irm_sep_dom.py

diff --git a/domainlab/algos/trainers/train_irm_sep_dom.py b/domainlab/algos/trainers/train_irm_sep_dom.py
new file mode 100644
index 000000000..94d3bca79
--- /dev/null
+++ b/domainlab/algos/trainers/train_irm_sep_dom.py
@@ -0,0 +1,39 @@
+"""
+use random start to generate adversarial images
+"""
+import torch
+from torch import autograd
+from torch.nn import functional as F
+from domainlab.algos.trainers.train_irm import TrainerIRM
+
+
+class TrainerIRMSepDom(TrainerIRM):
+    def tr_epoch(self, epoch, flag_info=False):
+        list_loaders = list(self.dict_loader_tr.values())
+        loaders_zip = zip(*list_loaders)
+        self.model.train()
+        self.epo_loss_tr = 0
+
+        for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip):
+            self.optimizer.zero_grad()
+            list_domain_loss_erm = []
+            list_domain_reg = []
+            for batch_domain_e in tuple_data_domains_batch:
+                tensor_x, tensor_y, tensor_d, *others = batch_domain_e
+                tensor_x, tensor_y, tensor_d = \
+                    tensor_x.to(self.device), tensor_y.to(self.device), \
+                    tensor_d.to(self.device)
+                list_domain_loss_erm.append(
+                    self.model.cal_task_loss(tensor_x, tensor_y))
+                list_1ele_loss_irm, _ = \
+                    self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+                list_domain_reg += list_1ele_loss_irm
+            loss = torch.sum(torch.stack(list_domain_loss_erm)) + \
+                self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg))
+            loss.backward()
+            self.optimizer.step()
+            self.epo_loss_tr += loss.detach().item()
+            self.after_batch(epoch, ind_batch)
+
+        flag_stop = self.observer.update(epoch, flag_info)  # notify observer
+        return flag_stop

From 63ad47b89bd771ffc8f98ecb7001fdff3964ecc3 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 18:13:38 +0200
Subject: [PATCH 09/40] dial mhof yaml

---
 examples/benchmark/aistat_dial_erm_mhof.yaml | 64 ++++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 examples/benchmark/aistat_dial_erm_mhof.yaml

diff --git a/examples/benchmark/aistat_dial_erm_mhof.yaml b/examples/benchmark/aistat_dial_erm_mhof.yaml
new file mode 100644
index 000000000..50e801766
--- /dev/null
+++ b/examples/benchmark/aistat_dial_erm_mhof.yaml
@@ -0,0 +1,64 @@
+mode: grid
+
+output_dir: zoutput/benchmarks/benchmark_mhof_dial
+
+sampling_seed: 0
+startseed: 0
+endseed: 2
+
+test_domains:
+  - sketch
+
+domainlab_args:
+  tpath: examples/tasks/task_pacs_aug.py
+  dmem: False
+  lr: 5e-5
+  epos: 500
+  epos_min: 200
+  es: 5
+  bs: 32
+  san_check: False
+  npath: examples/nets/resnet50domainbed.py
+  npath_dom: examples/nets/resnet50domainbed.py
+  zx_dim: 0
+  zy_dim: 64
+  zd_dim: 64
+
+
+Shared params:
+  ini_setpoint_ratio:
+    min: 0.5
+    max: 0.99
+    num: 2
+    step: 0.05
+    distribution: uniform
+
+  k_i_gain_ratio:
+    min: 0.1
+    max: 1
+    num: 3
+    distribution: uniform
+
+  gamma_reg:
+    min: 0.01
+    max: 1e4
+    num: 4
+    distribution: loguniform
+
+
+# Test fbopt with different hyperparameter configurations
+
+fbopt_dial_erm:
+  model: erm
+  trainer: fbopt_dial
+  shared:
+    - ini_setpoint_ratio
+    - k_i_gain_ratio
+
+dial_erm:
+  model: erm
+  trainer: dial
+  shared:
+    - gamma_reg

From 8ba7e110430938e35e572d8a79106d0196c91e59 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Wed, 9 Oct 2024 18:34:39 +0200
Subject: [PATCH 10/40] number of batches to estimate ratio

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index be8a3c480..eceef6919 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -28,6 +28,14 @@
 Shared params:
+  nb4reg_over_task_ratio:
+    distribution: categorical  # name of the distribution
+    datatype: int
+    values:  # concrete values to choose from
+      - 0
+      - 1
+      - 100
+
   ini_setpoint_ratio:
     min: 0.5
     max: 0.99
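The `nb4reg_over_task_ratio` hyperparameter introduced above caps how many minibatches feed the pre-training scale estimate. A sketch of the intended mechanism, based on PATCH 10 and the later fix in PATCH 15 (function and argument names are assumptions, not benchmark code):

```python
# Illustrative: estimate the reg-loss / task-loss scale ratio from the first
# n_batches minibatches only, so the probe stays cheap on large datasets.
def estimate_reg_over_task_ratio(loader, cal_losses, n_batches):
    accum_reg, accum_task = None, 0.0
    for ind_batch, batch in enumerate(loader):
        if ind_batch >= n_batches:              # cap the estimation cost
            break
        reg_list, task = cal_losses(batch)      # per-term reg losses, task loss
        accum_reg = reg_list if accum_reg is None else [
            a + b for a, b in zip(accum_reg, reg_list)]
        accum_task += task
    return [reg / accum_task for reg in accum_reg]
```

The ratio is later used to put each regularizer's multiplier on a scale comparable to the task loss before the feedback controller takes over.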
From 60734030a2e6af044cf2f88de908c5d9d9480267 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Wed, 9 Oct 2024 18:54:46 +0200
Subject: [PATCH 11/40] Update aistat_irm_erm_only.yaml

---
 examples/benchmark/aistat_irm_erm_only.yaml | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_only.yaml b/examples/benchmark/aistat_irm_erm_only.yaml
index 299ec528c..a194c4c63 100644
--- a/examples/benchmark/aistat_irm_erm_only.yaml
+++ b/examples/benchmark/aistat_irm_erm_only.yaml
@@ -14,7 +14,7 @@ domainlab_args:
   dmem: False
   lr: 5e-5
   epos: 500
-  epos_min: 200
+  epos_min: 10
   es: 5
   bs: 32
   san_check: False
@@ -28,6 +28,14 @@
 Shared params:
+  nb4reg_over_task_ratio:
+    distribution: categorical  # name of the distribution
+    datatype: int
+    values:  # concrete values to choose from
+      - 0
+      - 1
+      - 100
+
   ini_setpoint_ratio:
     min: 0.5
     max: 0.99
@@ -56,3 +64,4 @@ irm_erm:
   trainer: irm
   shared:
     - gamma_reg
+    - nb4reg_over_task_ratio

From 8bd1163e8e0b246d40d495109c6e9b6cf64b08c8 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Wed, 9 Oct 2024 18:56:49 +0200
Subject: [PATCH 12/40] Update aistat_irm_erm_mhof.yaml

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index eceef6919..dae4334b7 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -14,7 +14,7 @@ domainlab_args:
   dmem: False
   lr: 5e-5
   epos: 500
-  epos_min: 200
+  epos_min: 20
   es: 5
   bs: 32
   san_check: False
@@ -64,9 +64,11 @@ fbopt_irm_erm:
   shared:
     - ini_setpoint_ratio
    - k_i_gain_ratio
+    - nb4reg_over_task_ratio

 irm_erm:
   model: erm
   trainer: irm
   shared:
     - gamma_reg
+    - nb4reg_over_task_ratio

From 7667cbc902d0ef8a0c9797cae96181b32e95d7a8 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 9 Oct 2024 19:02:20 +0200
Subject: [PATCH 13/40] .

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index dae4334b7..003b7a616 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -1,6 +1,6 @@
 mode: grid

-output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs
+output_dir: zoutput/benchmarks/benchmark_mhof_irm_erm_pacs

 sampling_seed: 0
 startseed: 0

From b64e5d0f3abcab3f0400c3cbd09eb36f8dcd5e49 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Thu, 10 Oct 2024 08:37:33 +0200
Subject: [PATCH 14/40] Update aistat_irm_erm_only.yaml

---
 examples/benchmark/aistat_irm_erm_only.yaml | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/examples/benchmark/aistat_irm_erm_only.yaml b/examples/benchmark/aistat_irm_erm_only.yaml
index a194c4c63..2a9c011e8 100644
--- a/examples/benchmark/aistat_irm_erm_only.yaml
+++ b/examples/benchmark/aistat_irm_erm_only.yaml
@@ -65,3 +65,10 @@ irm_erm:
   shared:
     - gamma_reg
     - nb4reg_over_task_ratio
+
+irm_erm_dep_dom:
+  model: erm
+  trainer: irmsepdom
+  shared:
+    - gamma_reg
+    - nb4reg_over_task_ratio

From ba68e274345e9981438d1611495aba224fd8f31b Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 10 Oct 2024 09:52:45 +0200
Subject: [PATCH 15/40] fix num_batches for loss ratio estimate

---
 domainlab/algos/trainers/a_trainer.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index d9c4a13f5..99460afc9 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -221,8 +221,6 @@ def cal_reg_loss_over_task_loss_ratio(self):
         for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
             self.loader_tr
         ):
-            if ind_batch >= self.aconf.nb4reg_over_task_ratio:
-                return
             tensor_x, tensor_y, tensor_d = (
                 tensor_x.to(self.device),
                 tensor_y.to(self.device),
@@ -230,6 +228,10 @@ def cal_reg_loss_over_task_loss_ratio(self):
             )
             list_reg_loss_tensor, list_mu = \
                 self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+
+            if ind_batch >= self.aconf.nb4reg_over_task_ratio:
+                return list_mu
+
             list_reg_loss_tensor = [torch.sum(tensor).detach().item()
                                     for tensor in list_reg_loss_tensor]
             if ind_batch == 0:
@@ -313,9 +315,12 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
             list_reg_loss_trainer_tensor
         list_mu = list_mu_model + list_mu_trainer
         # ERM return a tensor of all zeros, delete here
-        list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() for i in range(len(list_mu))]
-        list_loss_tensor = [list_loss_tensor[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
-        list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
+        if len(list_mu) > 1:
+            list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item()
+                                 for i in range(len(list_mu))]
+            list_loss_tensor = [list_loss_tensor[i] for (i, flag) in
+                                enumerate(list_boolean_zero) if not flag]
+            list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
         if self.dict_multiplier:
             list_mu = list(self.dict_multiplier.values())
         return list_loss_tensor, list_mu

From c09879d5970f541cca435ef3693309832b5e5ca1 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 10 Oct 2024 09:53:58 +0200
Subject: [PATCH 16/40] .

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index 003b7a616..326a82fc2 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -72,3 +72,11 @@ irm_erm:
   shared:
     - gamma_reg
     - nb4reg_over_task_ratio
+
+
+irm_erm_dep_dom:
+  model: erm
+  trainer: irmsepdom
+  shared:
+    - gamma_reg
+    - nb4reg_over_task_ratio

From 5585b4a2a6c02f2debc0eaf4c11ad43b962f13d6 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 10 Oct 2024 12:06:32 +0200
Subject: [PATCH 17/40] use correct hyper range

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 36 ++++++++++++---------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index 326a82fc2..0e36394b8 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -35,13 +35,22 @@ Shared params:
       - 0
       - 1
       - 100
-
-  ini_setpoint_ratio:
-    min: 0.5
-    max: 0.99
-    num: 2
-    step: 0.05
-    distribution: uniform
+
+  mu_init:
+    distribution: categorical
+    datatype: float
+    values:
+      - 1
+      - 0.001
+
+  mu_clip:
+    distribution: categorical
+    datatype: float
+    values:
+      - 1
+      - 10
+      - 100
+      - 1000

   k_i_gain_ratio:
     min: 0.1
@@ -61,10 +70,13 @@
 fbopt_irm_erm:
   model: erm
   trainer: fbopt_irm
+  ini_setpoint_ratio: 0.99
+
   shared:
-    - ini_setpoint_ratio
     - k_i_gain_ratio
     - nb4reg_over_task_ratio
+    - mu_clip
+    - mu_init

 irm_erm:
   model: erm
@@ -72,11 +84,3 @@ irm_erm:
   shared:
     - gamma_reg
     - nb4reg_over_task_ratio
-
-
-irm_erm_dep_dom:
-  model: erm
-  trainer: irmsepdom
-  shared:
-    - gamma_reg
-    - nb4reg_over_task_ratio
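The yaml restructurings above switch several shared parameters from interval sampling to `categorical` grids. A hypothetical illustration of what such a grid expands to in Python terms (the names `shared` and the cross-product behavior are assumptions about the benchmark runner, shown only for intuition):

```python
# Hypothetical: how categorical shared params under mode: grid multiply out,
# before being crossed further with seeds startseed..endseed.
from itertools import product

shared = {"mu_init": [1e-6, 1e-3], "mu_clip": [1, 10, 100, 1000]}
grid = [dict(zip(shared, combo)) for combo in product(*shared.values())]
print(len(grid))  # 8 hyperparameter configurations
```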
From d176e1a110284b774e745201f00876b48c3ab8e4 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Thu, 10 Oct 2024 13:34:36 +0200
Subject: [PATCH 18/40] Update aistat_irm_erm_mhof.yaml

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index 0e36394b8..d33769bbe 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -14,7 +14,7 @@ domainlab_args:
   dmem: False
   lr: 5e-5
   epos: 500
-  epos_min: 20
+  epos_min: 10
   es: 5
   bs: 32
   san_check: False

From 3a4226e68cc3a818069691bcc1dae896613257dd Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 10 Oct 2024 13:37:22 +0200
Subject: [PATCH 19/40] add unit test

---
 tests/test_irm.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/test_irm.py b/tests/test_irm.py
index 235b9e4ce..42b12811f 100644
--- a/tests/test_irm.py
+++ b/tests/test_irm.py
@@ -12,6 +12,16 @@ def test_irm():
         --trainer=irm --nname=alexnet"
     utils_test_algo(args)

+def test_irm_sepdom():
+    """
+    train with Invariant Risk Minimization
+    """
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
+        --trainer=irmsepdom --nname=alexnet"
+    utils_test_algo(args)
+
+
+

 def test_irm_scheduler():
     """

From 011cccbbf2ba3f76f6ea23236102c15ae958f356 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Thu, 10 Oct 2024 14:45:58 +0200
Subject: [PATCH 20/40] Update aistat_irm_erm_mhof.yaml

---
 examples/benchmark/aistat_irm_erm_mhof.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index d33769bbe..9859e4281 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -40,7 +40,7 @@ Shared params:
     distribution: categorical
     datatype: float
     values:
-      - 1
+      - 0.000001
       - 0.001

   mu_clip:
@@ -63,6 +63,7 @@ Shared params:
     max: 1e4
     num: 4
     distribution: loguniform
+    # 1778 is the largest gamma_reg using this sampling

From ae03219338d68e68ca5c5ab92b03b30e4ee36462 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Thu, 10 Oct 2024 15:35:50 +0200
Subject: [PATCH 21/40] Update and rename pacs_fbopt_dial_diva.yaml to
 aistat_pacs_mhof_dial_diva.yaml

---
 ...a.yaml => aistat_pacs_mhof_dial_diva.yaml} | 26 +++++++------------
 1 file changed, 9 insertions(+), 17 deletions(-)
 rename examples/benchmark/{pacs_fbopt_dial_diva.yaml => aistat_pacs_mhof_dial_diva.yaml} (81%)

diff --git a/examples/benchmark/pacs_fbopt_dial_diva.yaml b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml
similarity index 81%
rename from examples/benchmark/pacs_fbopt_dial_diva.yaml
rename to examples/benchmark/aistat_pacs_mhof_dial_diva.yaml
index ca2cf3921..ba0753e53 100644
--- a/examples/benchmark/pacs_fbopt_dial_diva.yaml
+++ b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml
@@ -14,7 +14,7 @@ domainlab_args:
   dmem: False
   lr: 5e-5
   epos: 500
-  epos_min: 200
+  epos_min: 20
   es: 5
   bs: 16
   san_check: False
@@ -28,13 +28,6 @@ domainlab_args:
 Shared params:
-  ini_setpoint_ratio:
-    min: 0.9
-    max: 0.99
-    num: 2
-    step: 0.05
-    distribution: uniform
-
   k_i_gain_ratio:
     min: 0.01
     max: 0.90
@@ -78,16 +71,15 @@ dial_fbopt:
   trainer: fbopt_dial
   gamma_y: 1.0
   shared:
-    - ini_setpoint_ratio
     - k_i_gain_ratio
     - dial_lr
    - dial_epsilon

-# dial:
-#   model: diva
-#   trainer: dial
-#   shared:
-#     - dial_lr
-#     - dial_epsilon
-#     - gamma_y
-#     - gamma_d
+dial_diva:
+  model: diva
+  trainer: dial
+  shared:
+    - dial_lr
+    - dial_epsilon
+    - gamma_y
+    - gamma_d

From 3166eb24c4b14fe9bf1d7575408eb3a8899914bc Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 10 Oct 2024 16:21:20 +0200
Subject: [PATCH 22/40] dial

---
 examples/benchmark/aistat_dial_erm_mhof.yaml | 33 +++++++++++++++-----
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/examples/benchmark/aistat_dial_erm_mhof.yaml b/examples/benchmark/aistat_dial_erm_mhof.yaml
index 50e801766..ad9304402 100644
--- a/examples/benchmark/aistat_dial_erm_mhof.yaml
+++ b/examples/benchmark/aistat_dial_erm_mhof.yaml
@@ -28,12 +28,28 @@ domainlab_args:
 Shared params:
-  ini_setpoint_ratio:
-    min: 0.5
-    max: 0.99
-    num: 2
-    step: 0.05
-    distribution: uniform
+  nb4reg_over_task_ratio:
+    distribution: categorical  # name of the distribution
+    datatype: int
+    values:  # concrete values to choose from
+      - 0
+      - 100
+
+  mu_init:
+    distribution: categorical
+    datatype: float
+    values:
+      - 0.000001
+      - 0.001
+
+  mu_clip:
+    distribution: categorical
+    datatype: float
+    values:
+      - 1
+      - 10
+      - 100
+      - 1000

   k_i_gain_ratio:
     min: 0.1
@@ -53,12 +69,15 @@ Shared params:
 fbopt_dial_erm:
   model: erm
   trainer: fbopt_dial
+  nb4reg_over_task_ratio: 0
   shared:
-    - ini_setpoint_ratio
     - k_i_gain_ratio
+    - mu_init
+    - mu_clip

 dial_erm:
   model: erm
+  nb4reg_over_task_ratio: 0
   trainer: dial
   shared:
     - gamma_reg
help="disable setpoint update", ) + # FIXME: change arguments from str to boolean parser.add_argument( "--overshoot_rewind", type=str, @@ -113,14 +122,14 @@ def add_args2parser_fbopt(parser): "--setpoint_rewind", type=str, default="no", - help="setpoing_rewind, for benchmark, use yes or no", + help="rewind setpoint, for benchmark, use yes or no", ) parser.add_argument( "--str_diva_multiplier_type", type=str, default="gammad_recon", - help="which penalty to tune", + help="which penalty to tune, only useful to DIVA model", ) return parser From 98846403c582c9d3443e339a303d434b8d23e998 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 11 Oct 2024 11:13:39 +0200 Subject: [PATCH 24/40] detailed doc for mhof args, change str args to boolean, fix issue #777 --- domainlab/algos/trainers/args_fbopt.py | 49 ++++++++++++------- .../algos/trainers/fbopt_mu_controller.py | 7 +-- .../algos/trainers/fbopt_setpoint_ada.py | 6 +-- examples/benchmark/aistat_dial_erm_mhof.yaml | 1 + examples/benchmark/aistat_irm_erm_mhof.yaml | 10 +++- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py index 38ff70b55..c738d76ff 100644 --- a/domainlab/algos/trainers/args_fbopt.py +++ b/domainlab/algos/trainers/args_fbopt.py @@ -11,23 +11,31 @@ def add_args2parser_fbopt(parser): parser.add_argument( "--k_i_gain", type=float, default=0.001, help="PID control gain for integrator, if k_i_gain_ratio is not None, \ - then this value will be overwriten" + then this value will be overwriten, see doc for k_i_gain_ratio" ) parser.add_argument( "--k_i_gain_ratio", type=float, default=None, - help="set k_i_gain to be ratio of \ - initial saturation k_i_gain", + help="set k_i_gain to be ratio of initial saturation k_i_gain \ + which K_I * delta = exp_shoulder_clip(saturation value), solve \ + for K_I, where delta = reg loss - setpoint. \ + Now independent of the scale of delta, the K_I gain will be set so \ + that the multiplier will be increased at a rate defined by \ + exp_shoulder_clip", ) parser.add_argument( - "--mu_clip", type=float, default=1e4, help="maximum value of mu" + "--mu_clip", type=float, default=1e4, + help="maximum value of mu: mu_clip should be large enough so that the \ + regularization loss as penalty can be weighed superior enough to \ + decrease." 
From 98846403c582c9d3443e339a303d434b8d23e998 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 11:13:39 +0200
Subject: [PATCH 24/40] detailed doc for mhof args, change str args to
 boolean, fix issue #777

---
 domainlab/algos/trainers/args_fbopt.py        | 49 ++++++++++++-------
 .../algos/trainers/fbopt_mu_controller.py     |  7 +--
 .../algos/trainers/fbopt_setpoint_ada.py      |  6 +--
 examples/benchmark/aistat_dial_erm_mhof.yaml  |  1 +
 examples/benchmark/aistat_irm_erm_mhof.yaml   | 10 +++-
 5 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py
index 38ff70b55..c738d76ff 100644
--- a/domainlab/algos/trainers/args_fbopt.py
+++ b/domainlab/algos/trainers/args_fbopt.py
@@ -11,23 +11,31 @@ def add_args2parser_fbopt(parser):
     parser.add_argument(
         "--k_i_gain", type=float, default=0.001,
         help="PID control gain for integrator, if k_i_gain_ratio is not None, \
-        then this value will be overwriten"
+        then this value will be overwriten, see doc for k_i_gain_ratio"
     )

     parser.add_argument(
         "--k_i_gain_ratio",
         type=float,
         default=None,
-        help="set k_i_gain to be ratio of \
-        initial saturation k_i_gain",
+        help="set k_i_gain to be ratio of initial saturation k_i_gain \
+        which K_I * delta = exp_shoulder_clip(saturation value), solve \
+        for K_I, where delta = reg loss - setpoint. \
+        Now independent of the scale of delta, the K_I gain will be set so \
+        that the multiplier will be increased at a rate defined by \
+        exp_shoulder_clip",
     )

     parser.add_argument(
-        "--mu_clip", type=float, default=1e4, help="maximum value of mu"
+        "--mu_clip", type=float, default=1e4,
+        help="maximum value of mu: mu_clip should be large enough so that the \
+        regularization loss as penalty can be weighed superior enough to \
+        decrease."
     )

     parser.add_argument(
-        "--mu_min", type=float, default=1e-6, help="minimum value of mu"
+        "--mu_min", type=float, default=1e-6, help="minimum value of mu, mu \
+        can not be negative"
     )

     parser.add_argument(
@@ -37,14 +45,17 @@ def add_args2parser_fbopt(parser):
     parser.add_argument(
         "--coeff_ma", type=float, default=0.5,
-        help="exponential moving average"
+        help="exponential moving average of delta \
+        (reg minus setpoint as control error): \
+        move_ave=move_ave + coeff*delta(current value)"
     )

     parser.add_argument(
         "--coeff_ma_output_state",
         type=float,
         default=0.1,
-        help="output (reguarization loss) exponential moving average",
+        help="output (reguarization loss) exponential moving average \
+        move_ave=move_ave*coeff + reg(current value)",
     )

     parser.add_argument(
@@ -63,7 +74,8 @@ def add_args2parser_fbopt(parser):
         exponential magnifies control error, so this argument \
         defines the maximum rate of change of multipliers \
-        exp(5)=148",
+        exp(5)=148, exp_shoulder_clip should not be too big, \
+        if exp_shoulder_clip is small, then more like exterior point method",
     )

     parser.add_argument(
@@ -71,7 +83,7 @@ def add_args2parser_fbopt(parser):
         type=float,
         default=0.99,
         help="before training start, evaluate reg loss, \
-        setpoint will be 0.9 of this loss",
+            setpoint will be 0.9 of this loss",
     )

     parser.add_argument(
@@ -85,7 +97,7 @@ def add_args2parser_fbopt(parser):
         "--force_setpoint_change_once",
         action="store_true",
         default=False,
-        help="continue trainiing until the setpoint changed at least once: \
+        help="continue training until the setpoint changed at least once: \
         up to maximum epos specified",
     )

     parser.add_argument(
@@ -110,21 +122,22 @@ def add_args2parser_fbopt(parser):
         help="disable setpoint update",
     )

-    # FIXME: change arguments from str to boolean
     parser.add_argument(
-        "--overshoot_rewind",
-        type=str,
-        default="yes",
-        help="overshoot_rewind, for benchmark, use yes or no",
+        "--no_overshoot_rewind",
+        action="store_true",
+        default=False,
+        help="disable overshoot rewind: when reg loss satisfies setpoint \
+        already, then set activation=K_I*delta = 0",
    )

     parser.add_argument(
         "--setpoint_rewind",
-        type=str,
-        default="no",
-        help="rewind setpoint, for benchmark, use yes or no",
+        action="store_true",
+        default=False,
+        help="rewind setpoint",
     )

+    # this arg is only used when model is set to be "diva"
     parser.add_argument(
         "--str_diva_multiplier_type",
         type=str,
         default="gammad_recon",
         help="which penalty to tune, only useful to DIVA model",
     )
     return parser

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 5c635e147..bfdccbcbd 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -53,13 +53,13 @@ def __init__(self, trainer, **kwargs):
         self.k_i_control = [trainer.aconf.k_i_gain for i in
                             range(len(self.mmu))]
         self.k_i_gain_ratio = trainer.aconf.k_i_gain_ratio
-        self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes"
+        self.overshoot_rewind = not trainer.aconf.no_overshoot_rewind
         self.delta_epsilon_r = None
         # NOTE: this value will be set according to initial evaluation of
         # neural network
         self.activation_clip = trainer.aconf.exp_shoulder_clip
-        self.coeff_ma = trainer.aconf.coeff_ma
+        self.coeff4newval_ma_delta = trainer.aconf.coeff_ma
         # NOTE:
         # print(copy.deepcopy(self.model))
         # TypeError: cannot pickle '_thread.lock' object
@@ -125,7 +125,8 @@ def cal_delta4control(self, list1, list_setpoint):
         # self.delta_epsilon_r is the previous time step.
         # delta_epsilon_r is the current time step
         self.delta_epsilon_r = self.cal_delta_integration(
-            self.delta_epsilon_r, delta_epsilon_r, self.coeff_ma
+            self.delta_epsilon_r, delta_epsilon_r,
+            self.coeff4newval_ma_delta
         )

     def cal_delta_integration(self, list_old, list_new, coeff):

diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py
index c3c0193ce..d243fc5e8 100644
--- a/domainlab/algos/trainers/fbopt_setpoint_ada.py
+++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -81,7 +81,7 @@ def __init__(self, host):
         self.counter = None
         self.epo_ma = None
         self.ref = None
-        self.coeff_ma = 0.5
+        self.coeff_ma_setpoint_rewinder = 0.5
         self.setpoint_rewind = host.flag_setpoint_rewind

     def reset(self, epo_reg_loss):
@@ -98,7 +98,7 @@ def observe(self, epo_reg_loss):
         """
         if self.ref is None:
             self.reset(epo_reg_loss)
-        self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma)
+        self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma_setpoint_rewinder)
         list_comparison_increase = [a < b for a, b in zip(self.ref, self.epo_ma)]
         list_comparison_above_setpoint = [
             a < b for a, b in zip(self.host.setpoint4R, self.epo_ma)
@@ -146,7 +146,7 @@ def __init__(self, state=None, args=None):
         else:
             state = DominateAllComponent()
         self.transition_to(state)
-        self.flag_setpoint_rewind = args.setpoint_rewind == "yes"
+        self.flag_setpoint_rewind = args.setpoint_rewind
         self.setpoint_rewinder = SetpointRewinder(self)
         self.state_task_loss = 0.0
         self.state_epo_reg_loss = [

diff --git a/examples/benchmark/aistat_dial_erm_mhof.yaml b/examples/benchmark/aistat_dial_erm_mhof.yaml
index ad9304402..ab530d72f 100644
--- a/examples/benchmark/aistat_dial_erm_mhof.yaml
+++ b/examples/benchmark/aistat_dial_erm_mhof.yaml
@@ -70,6 +70,7 @@
 fbopt_dial_erm:
   model: erm
   trainer: fbopt_dial
   nb4reg_over_task_ratio: 0
+  force_setpoint_change_once: True
   shared:
     - k_i_gain_ratio
     - mu_init

diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml
index 9859e4281..f4a508896 100644
--- a/examples/benchmark/aistat_irm_erm_mhof.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof.yaml
@@ -52,6 +52,14 @@ Shared params:
       - 100
       - 1000

+  exp_shoulder_clip:
+    distribution: categorical
+    datatype: float
+    values:
+      - 1
+      - 2
+      - 5
+
   k_i_gain_ratio:
     min: 0.1
     max: 1
@@ -72,7 +80,7 @@
 fbopt_irm_erm:
   model: erm
   trainer: fbopt_irm
   ini_setpoint_ratio: 0.99
-
+  force_setpoint_change_once: True
   shared:
     - k_i_gain_ratio
     - nb4reg_over_task_ratio
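The renamed coefficients above all feed the same primitive: `list_ma`, an element-wise exponential moving average over lists of loss values. A minimal sketch under the assumption that it is the usual convex combination (the help strings in this patch write the formula loosely, so this is an assumed, not confirmed, implementation):

```python
# Assumed semantics of list_ma used by the setpoint rewinder and the delta
# integration: element-wise EMA, new observation weighted by coeff.
def list_ma(list_old, list_new, coeff):
    if list_old is None:
        return list(list_new)        # first observation initializes the state
    return [(1 - coeff) * old + coeff * new
            for old, new in zip(list_old, list_new)]
```

Smoothing the control error before it enters the integrator is what keeps one noisy minibatch epoch from yanking the multipliers around.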
From cd1591581479d408bdf6e0788b655fab33d2d829 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 12:00:06 +0200
Subject: [PATCH 25/40] new irm yaml file

---
 ...rm_only.yaml => aistat_irm_erm_mhof2.yaml} | 69 +++++++++++++------
 1 file changed, 47 insertions(+), 22 deletions(-)
 rename examples/benchmark/{aistat_irm_erm_only.yaml => aistat_irm_erm_mhof2.yaml} (51%)

diff --git a/examples/benchmark/aistat_irm_erm_only.yaml b/examples/benchmark/aistat_irm_erm_mhof2.yaml
similarity index 51%
rename from examples/benchmark/aistat_irm_erm_only.yaml
rename to examples/benchmark/aistat_irm_erm_mhof2.yaml
index 2a9c011e8..d9f8e4037 100644
--- a/examples/benchmark/aistat_irm_erm_only.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof2.yaml
@@ -1,6 +1,6 @@
 mode: grid

-output_dir: zoutput/benchmarks/benchmark_irm_erm
+output_dir: zoutput/benchmarks/benchmark_mhof_irm_erm_pacs

 sampling_seed: 0
 startseed: 0
@@ -13,7 +13,7 @@ domainlab_args:
   tpath: examples/tasks/task_pacs_aug.py
   dmem: False
   lr: 5e-5
-  epos: 500
+  epos: 100
   epos_min: 10
   es: 5
   bs: 32
   san_check: False
@@ -33,42 +33,67 @@ Shared params:
     datatype: int
     values:  # concrete values to choose from
       - 0
+      - 100
+
+  mu_init:
+    distribution: categorical
+    datatype: float
+    values:
+      - 0.000001
+      - 0.001
+
+  mu_clip:
+    distribution: categorical
+    datatype: float
+    values:
       - 1
+      - 10
       - 100
-
-  ini_setpoint_ratio:
-    min: 0.5
-    max: 0.99
-    num: 2
-    step: 0.05
-    distribution: uniform
+      - 1000
+
+  exp_shoulder_clip:
+    distribution: categorical
+    datatype: float
+    values:
+      - 1
+      - 5

-  k_i_gain:
-    min: 0.0001
-    max: 0.01
+  k_i_gain_ratio:
+    min: 0.1
+    max: 1
     num: 2
-    step: 0.0001
     distribution: uniform

   gamma_reg:
-    min: 0.01
-    max: 1e4
-    num: 10
-    distribution: loguniform
+    distribution: categorical
+    datatype: float
+    values:
+      - 0.01
+      - 0.1
+      - 1
+      - 10
+      - 100
+
+
+# 1778 is the largest gamma_reg using this sampling

 # Test fbopt with different hyperparameter configurations

-irm_erm:
+fbopt_irm_erm:
   model: erm
-  trainer: irm
+  trainer: fbopt_irm
+  ini_setpoint_ratio: 0.99
+  force_setpoint_change_once: True
+  mu_clip: 10000
   shared:
-    - gamma_reg
+    - k_i_gain_ratio
     - nb4reg_over_task_ratio
+    - mu_init
+    - exp_shoulder_clip

-irm_erm_dep_dom:
+irm_erm:
   model: erm
-  trainer: irmsepdom
+  trainer: irm
   shared:
     - gamma_reg
     - nb4reg_over_task_ratio

From 273422bf1fdb8fd13422df98ddd7dbe3bfe4bed6 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 12:17:32 +0200
Subject: [PATCH 26/40] .

---
 a_test_mhof_irm.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh
index 610d3a606..f3aaf549d 100644
--- a/a_test_mhof_irm.sh
+++ b/a_test_mhof_irm.sh
@@ -1 +1 @@
-python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5
+python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 --force_setpoint_change_once

From dbbda55ea1c3708524511088814ec51855ff8829 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 13:07:41 +0200
Subject: [PATCH 27/40] take square of irm loss, copy reg loss from decoratee
 to fbopt

---
 a_test_mhof_irm.sh                        | 2 +-
 domainlab/algos/trainers/train_fbopt_b.py | 8 ++++++++
 domainlab/algos/trainers/train_irm.py     | 2 ++
 3 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh
index f3aaf549d..85b9445cc 100644
--- a/a_test_mhof_irm.sh
+++ b/a_test_mhof_irm.sh
@@ -1 +1 @@
-python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 --force_setpoint_change_once
+python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500

diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index 43a5bf619..e786579c2 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -122,6 +122,7 @@ def before_tr(self):
             self.epo_task_loss_tr,
             self.epo_loss_tr,
         ) = self.eval_r_loss()
+
         self.hyper_scheduler.set_setpoint(
             [
                 ele * self.aconf.ini_setpoint_ratio
@@ -172,6 +173,13 @@ def tr_epoch(self, epoch, flag_info=False):
         if self._decoratee is not None:
             flag = self._decoratee.tr_epoch(epoch, self.flag_setpoint_updated)
+            # self._decoratee.tr_epoch here will call
+            # self._decoratee.after_epoch to log the losses, but it only stores
+            # the value into self._decoratee,
+            # so we have to manually copy the value here
+            self.epo_loss_tr = self._decoratee.epo_loss_tr
+            self.epo_reg_loss_tr = self._decoratee.epo_reg_loss_tr
+            self.epo_task_loss_tr = self._decoratee.epo_task_loss_tr
         else:
             flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
         # is it good to update setpoint after we know the new value of each loss?

diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index aaa9f2dd0..09748ee1f 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -33,10 +33,12 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         with torch.enable_grad():
             phi = self._cal_phi(tensor_x)
             dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_()
+            # interleave instances inside a minibatch
             loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2])
             loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
             grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
             grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
             loss_irm_scalar = torch.sum(grad_1 * grad_2)  # scalar
+            loss_irm_scalar = torch.square(loss_irm_scalar)
             loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
             return [loss_irm_tensor], [self.aconf.gamma_reg]

From 1cb2428219f7b02dcd3ca2321048130f132d0d10 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 13:23:50 +0200
Subject: [PATCH 28/40] .

---
 examples/benchmark/aistat_irm_erm_mhof2.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof2.yaml b/examples/benchmark/aistat_irm_erm_mhof2.yaml
index d9f8e4037..ffdde0425 100644
--- a/examples/benchmark/aistat_irm_erm_mhof2.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof2.yaml
@@ -1,6 +1,6 @@
 mode: grid

-output_dir: zoutput/benchmarks/benchmark_mhof_irm_erm_pacs
+output_dir: zoutput/benchmarks/mhof_irm_erm_pacs

 sampling_seed: 0
 startseed: 0

From 155f0df5dec9884431244329d522e3b8b8c3e093 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Fri, 11 Oct 2024 13:27:42 +0200
Subject: [PATCH 29/40] .

---
 examples/benchmark/aistat_irm_erm_mhof2.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/benchmark/aistat_irm_erm_mhof2.yaml b/examples/benchmark/aistat_irm_erm_mhof2.yaml
index ffdde0425..f6bc9467e 100644
--- a/examples/benchmark/aistat_irm_erm_mhof2.yaml
+++ b/examples/benchmark/aistat_irm_erm_mhof2.yaml
@@ -46,7 +46,6 @@ Shared params:
     distribution: categorical
     datatype: float
     values:
-      - 1
       - 10
       - 100
      - 1000
@@ -72,6 +72,7 @@ Shared params:
       - 1
       - 10
       - 100
+      - 1000

 # 1778 is the largest gamma_reg using this sampling
@@ -84,11 +84,11 @@ fbopt_irm_erm:
   model: erm
   trainer: fbopt_irm
   ini_setpoint_ratio: 0.99
   force_setpoint_change_once: True
-  mu_clip: 10000
   shared:
     - k_i_gain_ratio
     - nb4reg_over_task_ratio
     - mu_init
+    - mu_clip
     - exp_shoulder_clip
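PATCH 27's squaring of the IRM gradient-alignment term is small enough to verify standalone. A runnable miniature on dummy data (mirrors the diff's logic; the random tensors stand in for `self._cal_phi(tensor_x)` and the labels):

```python
# Minimal standalone version of the squared IRM penalty from PATCH 27.
import torch
from torch import autograd
from torch.nn import functional as F

logits = torch.randn(8, 3)             # stand-in for phi = self._cal_phi(x)
y = torch.randint(0, 3, (8,))
w = torch.tensor(1.).requires_grad_()  # dummy scale probing classifier optimality
loss_1 = F.cross_entropy(logits[::2] * w, y[::2])    # interleaved split
loss_2 = F.cross_entropy(logits[1::2] * w, y[1::2])
g1 = autograd.grad(loss_1, [w], create_graph=True)[0]
g2 = autograd.grad(loss_2, [w], create_graph=True)[0]
penalty = torch.square(torch.sum(g1 * g2))           # squared, per PATCH 27
print(penalty.item() >= 0)  # True
```

Squaring makes the penalty non-negative: the raw inner product `g1 * g2` can be negative, which would let the controller "reward" anti-aligned gradients and could push a setpoint-tracking scheme the wrong way.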
--- a_test_mhof_irm.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh index 85b9445cc..4168c5e8c 100644 --- a/a_test_mhof_irm.sh +++ b/a_test_mhof_irm.sh @@ -1 +1 @@ -python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 From b8731cb29d973276d141449ae54cb82f928c446b Mon Sep 17 00:00:00 2001 From: smilesun Date: Sun, 13 Oct 2024 14:16:29 +0200 Subject: [PATCH 31/40] Milestone: feedforward works now with trainers --- a_test_mhof_irm.sh | 2 +- domainlab/algos/trainers/train_dial.py | 9 +++++---- domainlab/algos/trainers/train_hyper_scheduler.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh index 4168c5e8c..653dec138 100644 --- a/a_test_mhof_irm.sh +++ b/a_test_mhof_irm.sh @@ -1 +1 @@ -python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py index 4fe700f45..711a83008 100644 --- a/domainlab/algos/trainers/train_dial.py +++ b/domainlab/algos/trainers/train_dial.py @@ -47,7 +47,8 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ _ = tensor_d _ = others - tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) - tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) - loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) - return [loss_dial], [get_gamma_reg(self.aconf, self.name)] + with torch.enable_grad(): + tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) + tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) + loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) + return [loss_dial], [get_gamma_reg(self.aconf, self.name)] diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index e595ea0fe..00b1cb1d3 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -25,7 +25,7 @@ def set_scheduler( flag_update_epoch: if hyper-parameters should be changed per epoch flag_update_batch: if hyper-parameters should be changed per batch """ - self.hyper_scheduler = self.model.hyper_init(scheduler) + self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self) # let model register its hyper-parameters to the scheduler self.flag_update_hyper_per_epoch = flag_update_epoch self.flag_update_hyper_per_batch = flag_update_batch @@ -37,12 +37,14 @@ def before_batch(self, epoch, ind_batch): should be set to epoch*self.num_batches + ind_batch """ if self.flag_update_hyper_per_batch: - 
self.model.hyper_update( + self.decoratee.hyper_update( epoch * self.num_batches + ind_batch, self.hyper_scheduler ) return super().before_batch(epoch, ind_batch) def before_tr(self): + if hasattr(self.decoratee, "before_tr"): + self.decoratee.before_tr() if self.hyper_scheduler is None: logger = Logger.get_logger() logger.warning( @@ -54,12 +56,14 @@ def before_tr(self): total_steps=self.aconf.warmup, flag_update_epoch=True, ) - super().before_tr() def tr_epoch(self, epoch, flag_info=False): """ update hyper-parameters only per epoch """ if self.flag_update_hyper_per_epoch: - self.model.hyper_update(epoch, self.hyper_scheduler) + self.decoratee.hyper_update(epoch, self.hyper_scheduler) + if hasattr(self.decoratee, "dict_multiplier"): + logger = Logger.get_logger() + logger.info(f"---before epoch, current multiplier: {self.decoratee.dict_multiplier}") return super().tr_epoch(epoch) From f4e47732fd1e675d29b6ac71ffc8d85a91614627 Mon Sep 17 00:00:00 2001 From: smilesun Date: Sun, 13 Oct 2024 14:16:45 +0200 Subject: [PATCH 32/40] yaml for feedforward --- a_test_feedforward_irm.sh | 1 + ...istat_trainer_combo_dial_irm_erm_mhof.yaml | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 a_test_feedforward_irm.sh create mode 100644 examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml diff --git a/a_test_feedforward_irm.sh b/a_test_feedforward_irm.sh new file mode 100644 index 000000000..007f3e9ac --- /dev/null +++ b/a_test_feedforward_irm.sh @@ -0,0 +1 @@ +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml new file mode 100644 index 000000000..919100cef --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml @@ -0,0 +1,102 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + +fbopt_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + +irm_dial: + model: erm + trainer: irm_dial + shared: + - gamma_reg + - nb4reg_over_task_ratio + +irm_dial_feedforward_full: + model: erm + trainer: hyperscheduler_irm_dial + shared: + - 
gamma_reg
+    - nb4reg_over_task_ratio

From e3ed8c6f9a4ba8567a824986940fa6b7a98573d9 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Tue, 15 Oct 2024 11:36:13 +0200
Subject: [PATCH 33/40] tr_with_init_mu

---
 a_test_mhof_irm.sh                             | 2 +-
 domainlab/algos/trainers/a_trainer.py          | 11 ++-
 .../algos/trainers/fbopt_mu_controller.py      | 6 +-
 domainlab/algos/trainers/train_basic.py        | 6 --
 domainlab/algos/trainers/train_fbopt_b.py      | 1 +
 ...istat_trainer_combo_dial_irm_erm_mhof.yaml  | 1 +
 ..._trainer_combo_dial_irm_erm_mhof_only.yaml  | 89 +++++++++++++++++++
 7 files changed, 107 insertions(+), 9 deletions(-)
 create mode 100644 examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml

diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh
index 653dec138..27d4e8658 100644
--- a/a_test_mhof_irm.sh
+++ b/a_test_mhof_irm.sh
@@ -1 +1 @@
-python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999
+python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 --nb4reg_over_task_ratio=0 --tr_with_init_mu

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 99460afc9..24e0ccc21 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -323,7 +323,16 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
         if self.dict_multiplier:
             list_mu = list(self.dict_multiplier.values())
-        return list_loss_tensor, list_mu
+
+        list_loss_tensor_normalized = list_loss_tensor
+        if self.list_reg_over_task_ratio:
+            assert len(list_mu) == len(self.list_reg_over_task_ratio)
+            list_loss_tensor_normalized = \
+                [reg_loss / reg_over_task_ratio if reg_over_task_ratio != 0
+                 else reg_loss for (reg_loss, reg_over_task_ratio)
+                 in zip(list_loss_tensor, self.list_reg_over_task_ratio)]
+
+        return list_loss_tensor_normalized, list_mu
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         """

diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index bfdccbcbd..5c8efdcac 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -91,7 +91,11 @@ def set_k_i_gain(self, epo_reg_loss):
         k_i_gain_saturate_min = min(k_i_gain_saturate)
         # NOTE: here we override the commandline arguments specification
         # for k_i_control, so k_i_control is not a hyperparameter anymore
-        self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate]
+        # self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate]
+        # k_I should be the same for each component; the control error already
+        # makes the multiplier magnification different
+        self.k_i_control = [self.k_i_gain_ratio * k_i_gain_saturate_min for i
+                            in range(len(delta_epsilon_r))]
         warnings.warn(
             f"hyperparameter k_i_gain disabled! 
\ replace with {self.k_i_control}" diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 0e7faf8e4..0e2cd5dec 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -96,12 +96,6 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): tensor_x, tensor_y, tensor_d, others ) list_mu_tr_normalized = list_mu_tr - if self.list_reg_over_task_ratio: - assert len(list_mu_tr) == len(self.list_reg_over_task_ratio) - list_mu_tr_normalized = \ - [mu / reg_over_task_ratio if reg_over_task_ratio != 0 - else mu for (mu, reg_over_task_ratio) - in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( list_reg_tr_batch, list_mu_tr_normalized ) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index e786579c2..2c1483aa9 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -117,6 +117,7 @@ def before_tr(self): if self.aconf.tr_with_init_mu: self.tr_with_init_mu() + # evaluate regularization loss list ( self.epo_reg_loss_tr, self.epo_task_loss_tr, diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml index 919100cef..35ca668ba 100644 --- a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml @@ -80,6 +80,7 @@ fbopt_irm_dial_erm: trainer: fbopt_irm_dial ini_setpoint_ratio: 0.99 force_setpoint_change_once: True + tr_with_init_mu: True shared: - k_i_gain_ratio - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml new file mode 100644 index 000000000..81cc10609 --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml @@ -0,0 +1,89 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + +fbopt_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + tr_with_init_mu: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip From e38fa758cdbda89246281adc8578d2243ac5fdad Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 15 Oct 2024 11:43:42 +0200 Subject: [PATCH 34/40] doc --- domainlab/algos/trainers/train_fbopt_b.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index 2c1483aa9..ff9302274 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -106,6 +106,7 @@ def before_batch(self, epoch, ind_batch): def before_tr(self): if hasattr(self.decoratee, "before_tr"): + # initialize self.decoratee.dict_multiplier self.decoratee.before_tr() self.flag_setpoint_updated = False if self.aconf.force_feedforward: From 044e3ccb514c892ebedd82a87fe1648fc95c054d Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 15 Oct 2024 11:52:24 +0200 Subject: [PATCH 35/40] . --- .../benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml index 81cc10609..9ecad287b 100644 --- a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml @@ -1,6 +1,6 @@ mode: grid -output_dir: zoutput/benchmarks/mhof_combo_irm_dial_pacs +output_dir: zoutput/benchmarks/only_mhof_combo_irm_dial_pacs sampling_seed: 0 startseed: 0 From aea6e907316118e6d5b482bb92360c75ff4aedb5 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 17 Oct 2024 10:29:03 +0200 Subject: [PATCH 36/40] no ma for setpoint --- a_test_mhof_irm.sh | 2 +- domainlab/algos/trainers/fbopt_mu_controller.py | 2 ++ domainlab/algos/trainers/train_fbopt_b.py | 11 +++++++++-- .../aistat_trainer_combo_dial_irm_erm_mhof_only.yaml | 3 ++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh index 27d4e8658..df138e487 100644 --- a/a_test_mhof_irm.sh +++ b/a_test_mhof_irm.sh @@ -1 +1 @@ -python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 --nb4reg_over_task_ratio=0 --tr_with_init_mu +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 5c8efdcac..ce53fc0df 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -88,6 +88,8 @@ def set_k_i_gain(self, epo_reg_loss): a / b for a, b in zip(list_active, delta_epsilon_r) ] + # FIXME: add max K_I gain here if initial delta is too small + k_i_gain_saturate_min = min(k_i_gain_saturate) # NOTE: here we override the commandline arguments specification # for k_i_control, so k_i_control is not a hyperparameter anymore diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index ff9302274..b47ffc69e 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -115,8 +115,6 @@ def before_tr(self): self.set_scheduler(scheduler=HyperSchedulerFeedback) self.set_model_with_mu() # very small value - if self.aconf.tr_with_init_mu: - self.tr_with_init_mu() # evaluate regularization loss 
list
         (
@@ -125,6 +123,15 @@ def before_tr(self):
             self.epo_reg_loss_tr,
             self.epo_task_loss_tr,
             self.epo_loss_tr,
         ) = self.eval_r_loss()
 
+        if self.aconf.tr_with_init_mu:
+            self.tr_with_init_mu()
+            # evaluate regularization loss list
+            (
+                self.epo_reg_loss_tr,
+                self.epo_task_loss_tr,
+                self.epo_loss_tr,
+            ) = self.eval_r_loss()
+
         self.hyper_scheduler.set_setpoint(
             [
                 ele * self.aconf.ini_setpoint_ratio

diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
index 9ecad287b..4dd325e7d 100644
--- a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
+++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
@@ -80,10 +80,11 @@ fbopt_irm_dial_erm:
   trainer: fbopt_irm_dial
   ini_setpoint_ratio: 0.99
   force_setpoint_change_once: True
+  coeff_ma_setpoint: 0.0
   tr_with_init_mu: True
+  nb4reg_over_task_ratio: 0
   shared:
     - k_i_gain_ratio
-    - nb4reg_over_task_ratio
     - mu_init
     - mu_clip
     - exp_shoulder_clip

From e5197f3d7efe63ca918c7b9e601ce5f5f27323f Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 17 Oct 2024 12:09:43 +0200
Subject: [PATCH 37/40] logger

---
 domainlab/algos/trainers/fbopt_setpoint_ada.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py
index d243fc5e8..f0e5f6a9a 100644
--- a/domainlab/algos/trainers/fbopt_setpoint_ada.py
+++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -311,4 +311,12 @@ def update_setpoint(self):
                 {self.host.state_task_loss}"
             )
             self.host.setpoint4ell = self.host.state_task_loss
+
+        if flag1 & flag2:
+            logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO")
+            logger.info(
+                "!!!!!!!!!In DominateAllComponent: \
+                besides each component of the reg loss shrinking, \
+                the task loss also decreased, which forms dominance!"
+            )
     return flag1 & flag2, list_pos

From dc58a5008819c31908f244948e1195a28f1a71b5 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 17 Oct 2024 12:16:36 +0200
Subject: [PATCH 38/40] setpoint ada as argument

---
 a_test_mhof_irm.sh                             | 2 +-
 domainlab/algos/trainers/args_fbopt.py         | 8 ++++++++
 domainlab/algos/trainers/fbopt_setpoint_ada.py | 4 +++-
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh
index df138e487..3cf1ae7ed 100644
--- a/a_test_mhof_irm.sh
+++ b/a_test_mhof_irm.sh
@@ -1 +1 @@
-python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0
+python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()"

diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py
index c738d76ff..e1fa536ac 100644
--- a/domainlab/algos/trainers/args_fbopt.py
+++ b/domainlab/algos/trainers/args_fbopt.py
@@ -137,6 +137,14 @@ def add_args2parser_fbopt(parser):
         help="rewind setpoint",
     )
 
+    # which setpoint adaptation criterion the fbopt setpoint controller uses
+    parser.add_argument(
+        "--str_setpoint_ada",
+        type=str,
+        default="DominateAllComponent()",
+        help="which setpoint adaptation criterion to use",
+    )
+
     # this arg is only used when model is set to be "diva"
     parser.add_argument(
         "--str_diva_multiplier_type",

diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py
index f0e5f6a9a..e22e13a75 100644
--- a/domainlab/algos/trainers/fbopt_setpoint_ada.py
+++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -144,7 +144,9 @@ def __init__(self, state=None, args=None):
         if args is not None and args.no_setpoint_update:
             state = FixedSetpoint()
         else:
-            state = DominateAllComponent()
+            # state = eval('DominateAllComponent()')
+            # state = DominateAllComponent()
+            state = eval(args.str_setpoint_ada)
         self.transition_to(state)
         self.flag_setpoint_rewind = args.setpoint_rewind
         self.setpoint_rewinder = SetpointRewinder(self)

From 9ddd0373832bef1e3d92d9c5e1882d560a31a5c8 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Thu, 17 Oct 2024 12:20:26 +0200
Subject: [PATCH 39/40] yaml file search setpoint ada

---
 .../aistat_trainer_combo_dial_irm_erm_mhof_only.yaml | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
index 4dd325e7d..f1221b2f8 100644
--- a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
+++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml
@@ -74,6 +74,14 @@ Shared params:
       - 100
       - 1000
 
+  str_setpoint_ada:
+    distribution: categorical
+    datatype: str
+    values:
+      - "SliderAnyComponent()"
+      - "SliderAllComponent()"
+      - "DominateAnyComponent()"
+
 
 fbopt_irm_dial_erm:
   model: erm
@@ -88,3 +96,4 @@ fbopt_irm_dial_erm:
     - mu_init
     - mu_clip
     - exp_shoulder_clip
+    - str_setpoint_ada

From 83143a98253eaeafe3ee96c7b2708e6d4e4b90b6 Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: 
Thu, 17 Oct 2024 12:32:31 +0200
Subject: [PATCH 40/40] Update a_trainer.py: change mu order

---
 domainlab/algos/trainers/a_trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 24e0ccc21..49c0c4077 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -322,7 +322,9 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
 enumerate(list_boolean_zero) if not flag]
         list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
         if self.dict_multiplier:
-            list_mu = list(self.dict_multiplier.values())
+            # list_mu = list(self.dict_multiplier.values())
+            # iterate keys in sorted order so list_mu keeps a deterministic order across calls
+            list_mu = [self.dict_multiplier[key] for key in sorted(self.dict_multiplier.keys())]
 
         list_loss_tensor_normalized = list_loss_tensor
         if self.list_reg_over_task_ratio: