From 1af128b34dfbb6427b15f2f3c8c39f8864cc5f37 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 2 Nov 2023 17:24:00 +0100 Subject: [PATCH 1/4] user shceduler itself to update model parameter --- domainlab/algos/trainers/train_fbopt_b.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index a82201382..e40c09b54 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -116,7 +116,8 @@ def set_model_with_mu(self): """ set model multipliers """ - self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)) + # self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)) + self.model.hyper_update(epoch=None, fun_scheduler=self.hyper_scheduler) def tr_epoch(self, epoch, flag_info=False): """ From fbefd73591040b84bb92c2cb16c44bdd77f772ed Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 7 Nov 2023 15:42:34 +0100 Subject: [PATCH 2/4] fix issue #507 --- domainlab/algos/trainers/fbopt_mu_controller.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py index 3193ae225..2429a3bd6 100644 --- a/domainlab/algos/trainers/fbopt_mu_controller.py +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -252,3 +252,6 @@ def update_setpoint(self, epo_reg_loss, epo_task_loss): update setpoint """ return self.set_point_controller.observe(epo_reg_loss, epo_task_loss) + + def __call__(self, epoch): + return self.mmu From bc9de7f1e65142cc04c8fbbdd6a626431c8fdf25 Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 7 Nov 2023 15:50:51 +0100 Subject: [PATCH 3/4] towards fix issue #507 --- domainlab/algos/trainers/train_fbopt_b.py | 10 ++++++---- domainlab/models/model_dann.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index a478d0880..ac1290703 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -86,10 +86,12 @@ def before_tr(self): self.flag_setpoint_updated = False if self.aconf.force_feedforward: self.set_scheduler(scheduler=HyperSchedulerWarmup) + self.flag_update_hyper_per_epoch = True + self.hyper_scheduler.set_steps(total_steps=self.aconf.warmup) else: self.set_scheduler(scheduler=HyperSchedulerFeedback) - self.set_model_with_mu() # very small value + self.set_model_with_mu(0) # very small value if self.aconf.tr_with_init_mu: self.tr_with_init_mu() @@ -113,12 +115,12 @@ def tr_with_init_mu(self): """ super().tr_epoch(-1) - def set_model_with_mu(self): + def set_model_with_mu(self, epoch): """ set model multipliers """ # self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)) - self.model.hyper_update(epoch=None, fun_scheduler=self.hyper_scheduler) + self.model.hyper_update(epoch=epoch, fun_scheduler=self.hyper_scheduler) def tr_epoch(self, epoch, flag_info=False): """ @@ -130,7 +132,7 @@ def tr_epoch(self, epoch, flag_info=False): self.epo_loss_tr, self.list_str_multiplier_na, miter=epoch) - self.set_model_with_mu() + self.set_model_with_mu(epoch) if hasattr(self.model, "dict_multiplier"): logger = Logger.get_logger() logger.info(f"current multiplier: {self.model.dict_multiplier}") diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py index 292acb8b5..574f9bacc 100644 --- a/domainlab/models/model_dann.py +++ b/domainlab/models/model_dann.py @@ -57,7 +57,7 @@ def __init__(self, list_str_y, list_str_d, self.net_encoder = net_encoder self.net_classifier = net_classifier self.net_discriminator = net_discriminator - + @property def list_str_multiplier_na(self): return ["alpha"] From 7ae8f608dbdc43355a217260e8ea8c4886fac17b Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 9 Nov 2023 09:44:07 +0100 Subject: [PATCH 4/4] add unit test --- tests/test_fbopt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index 3e100acd8..3add87cf6 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -33,3 +33,6 @@ def test_forcesetpoint_fbopt(): args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once" utils_test_algo(args) +def test_forcefeedforward_fbopt(): + args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --epos_min=100 --es=1 --force_feedforward" + utils_test_algo(args)