diff --git a/a_test_feedforward_irm.sh b/a_test_feedforward_irm.sh new file mode 100644 index 000000000..007f3e9ac --- /dev/null +++ b/a_test_feedforward_irm.sh @@ -0,0 +1 @@ +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh index 610d3a606..3cf1ae7ed 100644 --- a/a_test_mhof_irm.sh +++ b/a_test_mhof_irm.sh @@ -1 +1 @@ -python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5 +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()" diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 2503f6024..49c0c4077 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -53,6 +53,7 @@ def __init__(self, successor_node=None, extend=None): """ super().__init__(successor_node) self._model = None + # decoratee can be both model or trainer self._decoratee = extend self.task = None self.observer = None @@ -96,6 +97,8 @@ def __init__(self, successor_node=None, extend=None): self._ma_iter = 0 # self.list_reg_over_task_ratio = None + # mhof + self.dict_multiplier = {} @property @@ -199,7 +202,13 @@ def before_tr(self): """ before training, probe model performance """ - self.cal_reg_loss_over_task_loss_ratio() + list_mu = self.cal_reg_loss_over_task_loss_ratio() + self.dict_multiplier = {'mu4regloss'+str(i): value for i, value in enumerate(list_mu)} + + @property + def list_str_multiplier_na(self): + list_str = list(self.dict_multiplier.keys()) + return list_str def cal_reg_loss_over_task_loss_ratio(self): """ @@ -208,18 +217,21 @@ def cal_reg_loss_over_task_loss_ratio(self): """ list_accum_reg_loss = [] loss_task_agg = 0 + list_mu = None for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): - if ind_batch >= self.aconf.nb4reg_over_task_ratio: - return tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), tensor_y.to(self.device), tensor_d.to(self.device), ) - list_reg_loss_tensor, _ = \ + list_reg_loss_tensor, list_mu = \ self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + + if ind_batch >= self.aconf.nb4reg_over_task_ratio: + return list_mu + list_reg_loss_tensor = [torch.sum(tensor).detach().item() for tensor in list_reg_loss_tensor] if ind_batch == 0: @@ -235,6 +247,7 @@ def cal_reg_loss_over_task_loss_ratio(self): loss_task_agg += tensor_loss_task self.list_reg_over_task_ratio = [reg_loss / loss_task_agg for reg_loss in list_accum_reg_loss] + return list_mu def post_tr(self): """ @@ -301,7 +314,27 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): list_loss_tensor = list_reg_loss_model_tensor + \ list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer - return list_loss_tensor, list_mu + # ERM return a tensor of all zeros, delete here + if len(list_mu) > 1: + list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() + for i in range(len(list_mu))] + list_loss_tensor = [list_loss_tensor[i] for 
(i, flag) in + enumerate(list_boolean_zero) if not flag] + list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag] + if self.dict_multiplier: + #list_mu = list(self.dict_multiplier.values()) + # use sorted order of keys to keep the order of list_mu consistent each time + list_mu = [self.dict_multiplier[key] for key in sorted(self.dict_multiplier.keys())] + + list_loss_tensor_normalized = list_loss_tensor + if self.list_reg_over_task_ratio: + assert len(list_mu) == len(self.list_reg_over_task_ratio) + list_loss_tensor_normalized = \ + [reg_loss / reg_over_task_ratio if reg_over_task_ratio != 0 + else reg_loss for (reg_loss, reg_over_task_ratio) + in zip(list_loss_tensor, self.list_reg_over_task_ratio)] + + return list_loss_tensor_normalized, list_mu def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ @@ -326,3 +359,23 @@ def print_parameters(self): """ params = vars(self) print(f"Parameters of {type(self).__name__}: {params}") + + def hyper_init(self, functor_scheduler, trainer): + """ + initialize both trainer's multiplier and model's multiplier + """ + if not self.dict_multiplier: + raise RuntimeError("self.dict_multiplier empty!") + return functor_scheduler( + trainer=trainer, **self.dict_multiplier + ) + + def hyper_update(self, epoch, fun_scheduler): + """hyper_update. + + :param epoch: + :param fun_scheduler: + """ + dict_rst = fun_scheduler(epoch) + for key in self.dict_multiplier: + self.dict_multiplier[key] = dict_rst[key] diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py index 53719e05f..e1fa536ac 100644 --- a/domainlab/algos/trainers/args_fbopt.py +++ b/domainlab/algos/trainers/args_fbopt.py @@ -9,53 +9,73 @@ def add_args2parser_fbopt(parser): """ parser.add_argument( - "--k_i_gain", type=float, default=0.001, help="PID control gain for integrator" + "--k_i_gain", type=float, default=0.001, + help="PID control gain for the integrator; if k_i_gain_ratio is not None, \ + this value will be overwritten, see the doc for k_i_gain_ratio" ) parser.add_argument( "--k_i_gain_ratio", type=float, default=None, - help="set k_i_gain to be ratio of \ - initial saturation k_i_gain", + help="set k_i_gain to be a ratio of the initial saturation k_i_gain, \ + for which K_I * delta = exp_shoulder_clip (the saturation value); solve \ + for K_I, where delta = reg loss - setpoint. \ + Independent of the scale of delta, the K_I gain will then be set so \ + that the multiplier increases at a rate defined by \ + exp_shoulder_clip", ) parser.add_argument( - "--mu_clip", type=float, default=1e4, help="maximum value of mu" + "--mu_clip", type=float, default=1e4, + help="maximum value of mu: mu_clip should be large enough that the \ + regularization loss, as a penalty, can be weighted strongly enough to \ + decrease."
) parser.add_argument( - "--mu_min", type=float, default=1e-6, help="minimum value of mu" + "--mu_min", type=float, default=1e-6, help="minimum value of mu, mu \ + cannot be negative" ) parser.add_argument( - "--mu_init", type=float, default=0.001, help="initial beta for multiplication" + "--mu_init", type=float, default=0.001, + help="initial value for each component of the multiplier vector" ) parser.add_argument( - "--coeff_ma", type=float, default=0.5, help="exponential moving average" + "--coeff_ma", type=float, default=0.5, + help="exponential moving average of delta \ + (reg minus setpoint as control error): \ + move_ave = move_ave + coeff*delta(current value)" ) parser.add_argument( "--coeff_ma_output_state", type=float, default=0.1, - help="state exponential moving average of \ - reguarlization loss", + help="output (regularization loss) exponential moving average \ + move_ave = move_ave*coeff + reg(current value)", ) parser.add_argument( "--coeff_ma_setpoint", type=float, default=0.9, - help="setpoint average coeff for previous setpoint", + help="setpoint moving average (coeff for the previous setpoint)", ) parser.add_argument( "--exp_shoulder_clip", type=float, default=5, - help="clip before exponential operation", + help="clip delta (the control error): \ + R(reg loss) - b(setpoint) before the exponential operation: \ + exp[clip(R-b, exp_shoulder_clip)].\ + The exponential magnifies the control error, so this argument \ + defines the maximum rate of change of the multipliers; \ + exp(5) is roughly 148, so exp_shoulder_clip should not be too big; \ + if exp_shoulder_clip is small, the behavior is more like an exterior point method", ) parser.add_argument( @@ -63,7 +83,7 @@ def add_args2parser_fbopt(parser): type=float, default=0.99, help="before training start, evaluate reg loss, \ - setpoint will be 0.9 of this loss", + setpoint will be this ratio of the initial loss", ) parser.add_argument( @@ -77,8 +97,8 @@ def add_args2parser_fbopt(parser): "--force_setpoint_change_once", action="store_true", default=False, - help="train until the setpoint changed at least once \ - up to maximum epos specified", + help="continue training until the setpoint has changed at least once, \ + up to the maximum epos specified", ) parser.add_argument( @@ -103,24 +123,34 @@ def add_args2parser_fbopt(parser): ) parser.add_argument( - "--overshoot_rewind", - type=str, - default="yes", - help="overshoot_rewind, for benchmark, use yes or no", + "--no_overshoot_rewind", + action="store_true", + default=False, + help="disable overshoot rewind: when the reg loss already satisfies the setpoint, \ + set activation = K_I*delta = 0", ) parser.add_argument( "--setpoint_rewind", + action="store_true", + default=False, + help="rewind setpoint", + ) + + # setpoint adaptation strategy, see fbopt_setpoint_ada.py + parser.add_argument( + "--str_setpoint_ada", type=str, - default="no", - help="setpoing_rewind, for benchmark, use yes or no", + default="DominateAllComponent()", + help="which setpoint adaptation criterion to use", ) + # this arg is only used when model is set to be "diva" parser.add_argument( "--str_diva_multiplier_type", type=str, default="gammad_recon", - help="which penalty to tune", + help="which penalty to tune, only useful for the DIVA model", ) return parser diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 824638461..ce53fc0df 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -50,15 +50,16 @@ def __init__(self, trainer, **kwargs): self.mmu = {key: self.init_mu for 
key, val in self.mmu.items()} self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf) - self.k_i_control = trainer.aconf.k_i_gain - self.k_i_gain_ratio = None - self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes" + self.k_i_control = [trainer.aconf.k_i_gain for i in + range(len(self.mmu))] + self.k_i_gain_ratio = trainer.aconf.k_i_gain_ratio + self.overshoot_rewind = not trainer.aconf.no_overshoot_rewind self.delta_epsilon_r = None # NOTE: this value will be set according to initial evaluation of # neural network self.activation_clip = trainer.aconf.exp_shoulder_clip - self.coeff_ma = trainer.aconf.coeff_ma + self.coeff4newval_ma_delta = trainer.aconf.coeff_ma # NOTE: # print(copy.deepcopy(self.model)) # TypeError: cannot pickle '_thread.lock' object @@ -70,7 +71,10 @@ def __init__(self, trainer, **kwargs): def set_k_i_gain(self, epo_reg_loss): if self.k_i_gain_ratio is None: - return + if self.k_i_control: + return + raise RuntimeError("set either direct k_i_control value or \ + set k_i_gain_ratio, can not be both empty!") # NOTE: do not use self.cal_delta4control!!!! which will change # class member variables self.delta_epsilon_r! list_setpoint = self.get_setpoint4r() @@ -78,13 +82,22 @@ def set_k_i_gain(self, epo_reg_loss): delta_epsilon_r = [a - b for a, b in zip(epo_reg_loss, list_setpoint)] # to calculate self.delta_epsilon_r + list_active = [self.activation_clip for i in range(len(delta_epsilon_r))] + k_i_gain_saturate = [ - a / b for a, b in zip(self.activation_clip, delta_epsilon_r) + a / b for a, b in zip(list_active, delta_epsilon_r) ] + + # FIXME: add max K_I gain here if initial delta is too small + k_i_gain_saturate_min = min(k_i_gain_saturate) # NOTE: here we override the commandline arguments specification # for k_i_control, so k_i_control is not a hyperparameter anymore - self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min + # self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate] + # k_I should be the same for each component, the control error already + # make the multiplier magnification different + self.k_i_control = [self.k_i_gain_ratio * k_i_gain_saturate_min for i + in range(len(delta_epsilon_r))] warnings.warn( f"hyperparameter k_i_gain disabled! \ replace with {self.k_i_control}" @@ -118,7 +131,8 @@ def cal_delta4control(self, list1, list_setpoint): # self.delta_epsilon_r is the previous time step. 
# delta_epsilon_r is the current time step self.delta_epsilon_r = self.cal_delta_integration( - self.delta_epsilon_r, delta_epsilon_r, self.coeff_ma + self.delta_epsilon_r, delta_epsilon_r, + self.coeff4newval_ma_delta ) def cal_delta_integration(self, list_old, list_new, coeff): @@ -162,7 +176,7 @@ def cal_activation(self): """ setpoint = self.get_setpoint4r() activation = [ - self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val) + self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val) for i, val in enumerate(self.delta_epsilon_r) ] if self.activation_clip is not None: diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py index c3c0193ce..e22e13a75 100644 --- a/domainlab/algos/trainers/fbopt_setpoint_ada.py +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -81,7 +81,7 @@ def __init__(self, host): self.counter = None self.epo_ma = None self.ref = None - self.coeff_ma = 0.5 + self.coeff_ma_setpoint_rewinder = 0.5 self.setpoint_rewind = host.flag_setpoint_rewind def reset(self, epo_reg_loss): @@ -98,7 +98,7 @@ def observe(self, epo_reg_loss): """ if self.ref is None: self.reset(epo_reg_loss) - self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma) + self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma_setpoint_rewinder) list_comparison_increase = [a < b for a, b in zip(self.ref, self.epo_ma)] list_comparison_above_setpoint = [ a < b for a, b in zip(self.host.setpoint4R, self.epo_ma) @@ -144,9 +144,11 @@ def __init__(self, state=None, args=None): if args is not None and args.no_setpoint_update: state = FixedSetpoint() else: - state = DominateAllComponent() + # state = eval('DominateAllComponent()') + # state = DominateAllComponent() + state = eval(args.str_setpoint_ada) self.transition_to(state) - self.flag_setpoint_rewind = args.setpoint_rewind == "yes" + self.flag_setpoint_rewind = args.setpoint_rewind self.setpoint_rewinder = SetpointRewinder(self) self.state_task_loss = 0.0 self.state_epo_reg_loss = [ @@ -311,4 +313,12 @@ def update_setpoint(self): {self.host.state_task_loss}" ) self.host.setpoint4ell = self.host.state_task_loss + + if flag1 & flag2: + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info( + "In DominateAllComponent: \ + besides each component of the reg loss shrinking, \ + the task loss also decreased, which forms dominance!"
+ ) return flag1 & flag2, list_pos diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 100c274b4..0e2cd5dec 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -23,6 +23,7 @@ def before_tr(self): """ self.model.evaluate(self.loader_te, self.device) super().before_tr() + self.before_epoch() def before_epoch(self): """ @@ -79,7 +80,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): tensor_d.to(self.device), ) self.optimizer.zero_grad() - loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) + loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss.backward() self.optimizer.step() self.after_batch(epoch, ind_batch) @@ -94,17 +95,11 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) - list_mu_tr_normalized = list_mu_tr - if self.list_reg_over_task_ratio: - assert len(list_mu_tr) == len(self.list_reg_over_task_ratio) - list_mu_tr_normalized = \ - [mu / reg_over_task_ratio if reg_over_task_ratio != 0 - else mu for (mu, reg_over_task_ratio) - in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( list_reg_tr_batch, list_mu_tr_normalized ) + assert len(tensor_batch_reg_loss_penalized.shape) == 1 loss_erm_agg = g_tensor_batch_agg(loss_task) loss_reg_penalized_agg = g_tensor_batch_agg(tensor_batch_reg_loss_penalized) @@ -112,4 +107,4 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): self.model.multiplier4task_loss * loss_erm_agg + loss_reg_penalized_agg ) self.log_loss(list_reg_tr_batch, loss_task, loss_penalized) - return loss_penalized + return loss_penalized, list_reg_tr_batch, loss_erm_agg diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py index dbcb50eae..711a83008 100644 --- a/domainlab/algos/trainers/train_dial.py +++ b/domainlab/algos/trainers/train_dial.py @@ -47,15 +47,8 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ _ = tensor_d _ = others - tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) - tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) - loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) - return [loss_dial], [get_gamma_reg(self.aconf, self.name)] - - def hyper_init(self, functor_scheduler, trainer): - """ - initialize both trainer's multiplier and model's multiplier - """ - fun_scheduler = super().hyper_init(functor_scheduler, trainer) - return fun_scheduler - # FIXME: register also the trainer hyperpars: return functor_scheduler(trainer=trainer, gamma_reg=self.gamma_reg) + with torch.enable_grad(): + tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) + tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) + loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) + return [loss_dial], [get_gamma_reg(self.aconf, self.name)] diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index 1efe3ce58..b47ffc69e 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -43,7 +43,7 @@ def set_scheduler(self, scheduler): this class name will be created inside model """ # model.hyper_init will register the hyper-parameters of the model to scheduler - self.hyper_scheduler = 
self.model.hyper_init(scheduler, trainer=self) + self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self) def eval_r_loss(self): """ @@ -66,8 +66,10 @@ def eval_r_loss(self): vec_y.to(self.device), vec_d.to(self.device), ) - tuple_reg_loss = self.model.cal_reg_loss(tensor_x, vec_y, vec_d, others) - p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others) + tuple_reg_loss = self.decoratee.cal_reg_loss(tensor_x, vec_y, vec_d, others) + p_loss, *_ = self.decoratee.cal_loss(tensor_x, vec_y, vec_d, others) + if p_loss.dim() > 0: + p_loss = p_loss.sum() # NOTE: first [0] extract the loss, second [0] get the list list_b_reg_loss = tuple_reg_loss[0] list_b_reg_loss_sumed = [ @@ -82,7 +84,7 @@ def eval_r_loss(self): ) # sum will kill the dimension of the mini batch epo_task_loss += b_task_loss - epo_p_loss += p_loss.sum().detach().item() + epo_p_loss += p_loss.detach().item() counter += 1.0 return ( list_divide(epo_reg_loss, counter), @@ -103,6 +105,9 @@ def before_batch(self, epoch, ind_batch): return super().after_batch(epoch, ind_batch) def before_tr(self): + if hasattr(self.decoratee, "before_tr"): + # initialize self.decoratee.dict_multiplier + self.decoratee.before_tr() self.flag_setpoint_updated = False if self.aconf.force_feedforward: self.set_scheduler(scheduler=HyperSchedulerWarmupLinear) @@ -110,14 +115,23 @@ def before_tr(self): self.set_scheduler(scheduler=HyperSchedulerFeedback) self.set_model_with_mu() # very small value - if self.aconf.tr_with_init_mu: - self.tr_with_init_mu() + # evaluate regularization loss list ( self.epo_reg_loss_tr, self.epo_task_loss_tr, self.epo_loss_tr, ) = self.eval_r_loss() + + if self.aconf.tr_with_init_mu: + self.tr_with_init_mu() + # evaluate regularization loss list + ( + self.epo_reg_loss_tr, + self.epo_task_loss_tr, + self.epo_loss_tr, + ) = self.eval_r_loss() + self.hyper_scheduler.set_setpoint( [ ele * self.aconf.ini_setpoint_ratio @@ -126,7 +140,7 @@ def before_tr(self): for ele in self.epo_reg_loss_tr ], self.epo_task_loss_tr, - ) # setpoing w.r.t. random initialization of neural network + ) # setpoint w.r.t. 
random initialization of neural network self.hyper_scheduler.set_k_i_gain(self.epo_reg_loss_tr) @property @@ -134,7 +148,7 @@ def list_str_multiplier_na(self): """ return the name of multipliers """ - return self.model.list_str_multiplier_na + return self.decoratee.list_str_multiplier_na def tr_with_init_mu(self): """ @@ -146,7 +160,7 @@ def set_model_with_mu(self): """ set model multipliers """ - self.model.hyper_update( + self.decoratee.hyper_update( epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu) ) @@ -162,12 +176,19 @@ def tr_epoch(self, epoch, flag_info=False): miter=epoch, ) self.set_model_with_mu() - if hasattr(self.model, "dict_multiplier"): + if hasattr(self.decoratee, "dict_multiplier"): logger = Logger.get_logger() - logger.info(f"current multiplier: {self.model.dict_multiplier}") + logger.info(f"current multiplier: {self.decoratee.dict_multiplier}") if self._decoratee is not None: flag = self._decoratee.tr_epoch(epoch, self.flag_setpoint_updated) + # self._decoratee.tr_epoch here will call + # self._decoratee.after_epoch to log the losses, but it only sotre + # the value into self._decoratee, + # so we have to mannually copy the value here + self.epo_loss_tr = self._decoratee.epo_loss_tr + self.epo_reg_loss_tr = self._decoratee.epo_reg_loss_tr + self.epo_task_loss_tr = self._decoratee.epo_task_loss_tr else: flag = super().tr_epoch(epoch, self.flag_setpoint_updated) # is it good to update setpoint after we know the new value of each loss? diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index e595ea0fe..00b1cb1d3 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -25,7 +25,7 @@ def set_scheduler( flag_update_epoch: if hyper-parameters should be changed per epoch flag_update_batch: if hyper-parameters should be changed per batch """ - self.hyper_scheduler = self.model.hyper_init(scheduler) + self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self) # let model register its hyper-parameters to the scheduler self.flag_update_hyper_per_epoch = flag_update_epoch self.flag_update_hyper_per_batch = flag_update_batch @@ -37,12 +37,14 @@ def before_batch(self, epoch, ind_batch): should be set to epoch*self.num_batches + ind_batch """ if self.flag_update_hyper_per_batch: - self.model.hyper_update( + self.decoratee.hyper_update( epoch * self.num_batches + ind_batch, self.hyper_scheduler ) return super().before_batch(epoch, ind_batch) def before_tr(self): + if hasattr(self.decoratee, "before_tr"): + self.decoratee.before_tr() if self.hyper_scheduler is None: logger = Logger.get_logger() logger.warning( @@ -54,12 +56,14 @@ def before_tr(self): total_steps=self.aconf.warmup, flag_update_epoch=True, ) - super().before_tr() def tr_epoch(self, epoch, flag_info=False): """ update hyper-parameters only per epoch """ if self.flag_update_hyper_per_epoch: - self.model.hyper_update(epoch, self.hyper_scheduler) + self.decoratee.hyper_update(epoch, self.hyper_scheduler) + if hasattr(self.decoratee, "dict_multiplier"): + logger = Logger.get_logger() + logger.info(f"---before epoch, current multiplier: {self.decoratee.dict_multiplier}") return super().tr_epoch(epoch) diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py index 45797bf00..09748ee1f 100644 --- a/domainlab/algos/trainers/train_irm.py +++ b/domainlab/algos/trainers/train_irm.py @@ -19,36 +19,6 @@ class TrainerIRM(TrainerBasic): For more 
details, see section 3.2 and Appendix D of Arjovsky et al., “Invariant Risk Minimization.” """ - def tr_epoch(self, epoch, flag_info=False): - list_loaders = list(self.dict_loader_tr.values()) - loaders_zip = zip(*list_loaders) - self.model.train() - self.epo_loss_tr = 0 - - for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip): - self.optimizer.zero_grad() - list_domain_loss_erm = [] - list_domain_reg = [] - for batch_domain_e in tuple_data_domains_batch: - tensor_x, tensor_y, tensor_d, *others = batch_domain_e - tensor_x, tensor_y, tensor_d = \ - tensor_x.to(self.device), tensor_y.to(self.device), \ - tensor_d.to(self.device) - list_domain_loss_erm.append( - self.model.cal_task_loss(tensor_x, tensor_y)) - list_1ele_loss_irm, _ = \ - self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others) - list_domain_reg += list_1ele_loss_irm - loss = torch.sum(torch.stack(list_domain_loss_erm)) + \ - self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg)) - loss.backward() - self.optimizer.step() - self.epo_loss_tr += loss.detach().item() - self.after_batch(epoch, ind_batch) - - flag_stop = self.observer.update(epoch, flag_info) # notify observer - return flag_stop - def _cal_phi(self, tensor_x): logits = self.model.cal_logit_y(tensor_x) return logits @@ -60,12 +30,15 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): _ = tensor_d _ = others y = tensor_y - phi = self._cal_phi(tensor_x) - dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_() - loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2]) - loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2]) - grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0] - grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0] - loss_irm_scalar = torch.sum(grad_1 * grad_2) # scalar - loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0]) - return [loss_irm_tensor], [self.aconf.gamma_reg] + with torch.enable_grad(): + phi = self._cal_phi(tensor_x) + dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_() + # interleave instances inside a minibatch + loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2]) + loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2]) + grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0] + grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0] + loss_irm_scalar = torch.sum(grad_1 * grad_2) # scalar + loss_irm_scalar = torch.square(loss_irm_scalar) + loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0]) + return [loss_irm_tensor], [self.aconf.gamma_reg] diff --git a/domainlab/algos/trainers/train_irm_sep_dom.py b/domainlab/algos/trainers/train_irm_sep_dom.py new file mode 100644 index 000000000..94d3bca79 --- /dev/null +++ b/domainlab/algos/trainers/train_irm_sep_dom.py @@ -0,0 +1,39 @@ +""" +IRM trainer that computes the IRM penalty separately for each training domain +""" +import torch +from torch import autograd +from torch.nn import functional as F +from domainlab.algos.trainers.train_irm import TrainerIRM + + +class TrainerIRMSepDom(TrainerIRM): + def tr_epoch(self, epoch, flag_info=False): + list_loaders = list(self.dict_loader_tr.values()) + loaders_zip = zip(*list_loaders) + self.model.train() + self.epo_loss_tr = 0 + + for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip): + self.optimizer.zero_grad() + list_domain_loss_erm = [] + list_domain_reg = [] + for batch_domain_e in tuple_data_domains_batch: + tensor_x, tensor_y, tensor_d, *others = batch_domain_e + tensor_x, tensor_y, tensor_d = 
\ + tensor_x.to(self.device), tensor_y.to(self.device), \ + tensor_d.to(self.device) + list_domain_loss_erm.append( + self.model.cal_task_loss(tensor_x, tensor_y)) + list_1ele_loss_irm, _ = \ + self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + list_domain_reg += list_1ele_loss_irm + loss = torch.sum(torch.stack(list_domain_loss_erm)) + \ + self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg)) + loss.backward() + self.optimizer.step() + self.epo_loss_tr += loss.detach().item() + self.after_batch(epoch, ind_batch) + + flag_stop = self.observer.update(epoch, flag_info) # notify observer + return flag_stop diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py index 93b40acfe..dcf1785dc 100644 --- a/domainlab/algos/trainers/zoo_trainer.py +++ b/domainlab/algos/trainers/zoo_trainer.py @@ -11,6 +11,7 @@ from domainlab.algos.trainers.train_mldg import TrainerMLDG from domainlab.algos.trainers.train_fishr import TrainerFishr from domainlab.algos.trainers.train_irm import TrainerIRM +from domainlab.algos.trainers.train_irm_sep_dom import TrainerIRMSepDom from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL from domainlab.algos.trainers.train_coral import TrainerCoral @@ -57,6 +58,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None): chain = TrainerMLDG(chain) chain = TrainerFishr(chain) chain = TrainerIRM(chain) + chain = TrainerIRMSepDom(chain) chain = TrainerHyperScheduler(chain) chain = TrainerFbOpt(chain) chain = TrainerCausalIRL(chain) diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index 1f72eec0a..ea9ef5f47 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -243,3 +243,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): device = tensor_x.device bsize = tensor_x.shape[0] return [torch.zeros(bsize).to(device)], [0.0] + # return [], [] diff --git a/examples/benchmark/aistat_dial_erm_mhof.yaml b/examples/benchmark/aistat_dial_erm_mhof.yaml new file mode 100644 index 000000000..ab530d72f --- /dev/null +++ b/examples/benchmark/aistat_dial_erm_mhof.yaml @@ -0,0 +1,84 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_mhof_dial + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 4 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +fbopt_dial_erm: + model: erm + trainer: fbopt_dial + nb4reg_over_task_ratio: 0 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - mu_init + - mu_clip + +dial_erm: + model: erm + nb4reg_over_task_ratio: 0 + trainer: dial + shared: + - gamma_reg diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml 
b/examples/benchmark/aistat_irm_erm_mhof.yaml index be8a3c480..f4a508896 100644 --- a/examples/benchmark/aistat_irm_erm_mhof.yaml +++ b/examples/benchmark/aistat_irm_erm_mhof.yaml @@ -1,6 +1,6 @@ mode: grid -output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs +output_dir: zoutput/benchmarks/benchmark_mhof_irm_erm_pacs sampling_seed: 0 startseed: 0 @@ -14,7 +14,7 @@ domainlab_args: dmem: False lr: 5e-5 epos: 500 - epos_min: 200 + epos_min: 10 es: 5 bs: 32 san_check: False @@ -28,12 +28,37 @@ domainlab_args: Shared params: - ini_setpoint_ratio: - min: 0.5 - max: 0.99 - num: 2 - step: 0.05 - distribution: uniform + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 1 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 2 + - 5 k_i_gain_ratio: min: 0.1 @@ -46,6 +71,7 @@ Shared params: max: 1e4 num: 4 distribution: loguniform + # 1778 is the largest gamma_reg using this sampling # Test fbopt with different hyperparameter configurations @@ -53,12 +79,17 @@ Shared params: fbopt_irm_erm: model: erm trainer: fbopt_irm + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True shared: - - ini_setpoint_ratio - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_clip + - mu_init irm_erm: model: erm trainer: irm shared: - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_irm_erm_mhof2.yaml b/examples/benchmark/aistat_irm_erm_mhof2.yaml new file mode 100644 index 000000000..f6bc9467e --- /dev/null +++ b/examples/benchmark/aistat_irm_erm_mhof2.yaml @@ -0,0 +1,99 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_irm_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + # 1778 is the largest gamma_reg using this sampling + + +# Test fbopt with different hyperparameter configurations + +fbopt_irm_erm: + model: erm + trainer: fbopt_irm + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + +irm_erm: + model: erm + trainer: irm + shared: + - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_irm_erm_only.yaml b/examples/benchmark/aistat_irm_erm_only.yaml deleted file mode 100644 index 299ec528c..000000000 --- a/examples/benchmark/aistat_irm_erm_only.yaml +++ /dev/null @@ 
-1,58 +0,0 @@ -mode: grid - -output_dir: zoutput/benchmarks/benchmark_irm_erm - -sampling_seed: 0 -startseed: 0 -endseed: 2 - -test_domains: - - sketch - -domainlab_args: - tpath: examples/tasks/task_pacs_aug.py - dmem: False - lr: 5e-5 - epos: 500 - epos_min: 200 - es: 5 - bs: 32 - san_check: False - npath: examples/nets/resnet50domainbed.py - npath_dom: examples/nets/resnet50domainbed.py - zx_dim: 0 - zy_dim: 64 - zd_dim: 64 - - - - -Shared params: - ini_setpoint_ratio: - min: 0.5 - max: 0.99 - num: 2 - step: 0.05 - distribution: uniform - - k_i_gain: - min: 0.0001 - max: 0.01 - num: 2 - step: 0.0001 - distribution: uniform - - gamma_reg: - min: 0.01 - max: 1e4 - num: 10 - distribution: loguniform - - -# Test fbopt with different hyperparameter configurations - -irm_erm: - model: erm - trainer: irm - shared: - - gamma_reg diff --git a/examples/benchmark/pacs_fbopt_dial_diva.yaml b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml similarity index 81% rename from examples/benchmark/pacs_fbopt_dial_diva.yaml rename to examples/benchmark/aistat_pacs_mhof_dial_diva.yaml index ca2cf3921..ba0753e53 100644 --- a/examples/benchmark/pacs_fbopt_dial_diva.yaml +++ b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml @@ -14,7 +14,7 @@ domainlab_args: dmem: False lr: 5e-5 epos: 500 - epos_min: 200 + epos_min: 20 es: 5 bs: 16 san_check: False @@ -28,13 +28,6 @@ domainlab_args: Shared params: - ini_setpoint_ratio: - min: 0.9 - max: 0.99 - num: 2 - step: 0.05 - distribution: uniform - k_i_gain_ratio: min: 0.01 max: 0.90 @@ -78,16 +71,15 @@ dial_fbopt: trainer: fbopt_dial gamma_y: 1.0 shared: - - ini_setpoint_ratio - k_i_gain_ratio - dial_lr - dial_epsilon -# dial: -# model: diva -# trainer: dial -# shared: -# - dial_lr -# - dial_epsilon -# - gamma_y -# - gamma_d +dial_diva: + model: diva + trainer: dial + shared: + - dial_lr + - dial_epsilon + - gamma_y + - gamma_d diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml new file mode 100644 index 000000000..35ca668ba --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml @@ -0,0 +1,103 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + +fbopt_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + tr_with_init_mu: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + +irm_dial: + model: erm + trainer: irm_dial + shared: + - 
gamma_reg + - nb4reg_over_task_ratio + +irm_dial_feedforward_full: + model: erm + trainer: hyperscheduler_irm_dial + shared: + - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml new file mode 100644 index 000000000..f1221b2f8 --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml @@ -0,0 +1,99 @@ +mode: grid + +output_dir: zoutput/benchmarks/only_mhof_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + str_setpoint_ada: + distribution: categorical + datatype: str + values: + - "SliderAnyComponent()" + - "SliderAllComponent()" + - "DominateAnyComponent()" + + +fbopt_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + coeff_ma_setpoint: 0.0 + tr_with_init_mu: True + nb4reg_over_task_ratio: 0 + shared: + - k_i_gain_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + - str_setpoint_ada diff --git a/tests/test_fbopt_irm.py b/tests/test_fbopt_irm.py new file mode 100644 index 000000000..fb1f109f6 --- /dev/null +++ b/tests/test_fbopt_irm.py @@ -0,0 +1,14 @@ +""" + end-to-end test +""" +from tests.utils_test import utils_test_algo + + +def test_mhof_irm(): + """ + mhof-irm + """ + args = "--te_d=0 --task=mnistcolor10 --model=erm \ + --trainer=fbopt_irm --nname=conv_bn_pool_2 \ + --k_i_gain_ratio=0.5" + utils_test_algo(args) diff --git a/tests/test_irm.py b/tests/test_irm.py index 235b9e4ce..42b12811f 100644 --- a/tests/test_irm.py +++ b/tests/test_irm.py @@ -12,6 +12,16 @@ def test_irm(): --trainer=irm --nname=alexnet" utils_test_algo(args) +def test_irm_sepdom(): + """ + train with Invariant Risk Minimization, with the penalty computed separately per domain + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=irmsepdom --nname=alexnet" + utils_test_algo(args) + + + def test_irm_scheduler(): """