Update a_trainer.py: change mu order #890

Open
wants to merge 51 commits into base: mhof_dev from mhof_dev_vector_ki_gain
Commits (51)
b2b204d
copy branch fbopt_vector_ki_gain to mhof_dev
smilesun Oct 4, 2024
ca9766e
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 5, 2024
731a144
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
9e09396
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
b4ddafe
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
d53c863
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 6, 2024
f15478c
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 8, 2024
6626475
use self.decoratee instead of self.model
smilesun Oct 8, 2024
d5d3a0b
cmd script to test mhof irm
smilesun Oct 8, 2024
6701cbf
enable grad for irm inside torch.no_grad for mhof
smilesun Oct 9, 2024
888d714
filter out zero reg loss in abstract trainer
smilesun Oct 9, 2024
64bcc9c
trainer behaves like model, now decoratee's cal_loss has to be changed
smilesun Oct 9, 2024
22f343c
overwrite multiplier from scheduler(default static scheduler then no …
smilesun Oct 9, 2024
23607fb
per domain irm to separate file
smilesun Oct 9, 2024
28bdfa9
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 9, 2024
63ad47b
dial mhof yaml
smilesun Oct 9, 2024
8ba7e11
number of batches to estimate ratio
smilesun Oct 9, 2024
6073403
Update aistat_irm_erm_only.yaml
smilesun Oct 9, 2024
8bd1163
Update aistat_irm_erm_mhof.yaml
smilesun Oct 9, 2024
7667cbc
.
smilesun Oct 9, 2024
795478f
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
b64e5d0
Update aistat_irm_erm_only.yaml
smilesun Oct 10, 2024
6433f58
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
6f5bac1
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
ba68e27
fix num_batches for loss ratio estimate
smilesun Oct 10, 2024
c09879d
.
smilesun Oct 10, 2024
401e978
Merge branch 'mhof_dev' into mhof_dev_vector_ki_gain
smilesun Oct 10, 2024
5585b4a
use correct hyper range
smilesun Oct 10, 2024
d176e1a
Update aistat_irm_erm_mhof.yaml
smilesun Oct 10, 2024
3a4226e
add unit test
smilesun Oct 10, 2024
011cccb
Update aistat_irm_erm_mhof.yaml
smilesun Oct 10, 2024
ae03219
Update and rename pacs_fbopt_dial_diva.yaml to aistat_pacs_mhof_dial_…
smilesun Oct 10, 2024
3166eb2
dial
smilesun Oct 10, 2024
d8cee2d
fix partially issue #777
smilesun Oct 11, 2024
9884640
detailed doc for mhof args, change str args to boolean, fix issue #777
smilesun Oct 11, 2024
cd15915
new irm yaml file
smilesun Oct 11, 2024
273422b
.
smilesun Oct 11, 2024
dbbda55
take square of irm loss, copy reg loss from decoratee to fbopt
smilesun Oct 11, 2024
1cb2428
.
smilesun Oct 11, 2024
155f0df
.
smilesun Oct 11, 2024
c41e92b
.
smilesun Oct 11, 2024
b8731cb
Milestone: feedforward works now with trainers
smilesun Oct 13, 2024
f4e4773
yaml for feedforward
smilesun Oct 13, 2024
e3ed8c6
tr_with_init_mu
smilesun Oct 15, 2024
e38fa75
doc
smilesun Oct 15, 2024
044e3cc
.
smilesun Oct 15, 2024
aea6e90
no ma for setpoint
smilesun Oct 17, 2024
e5197f3
logger
smilesun Oct 17, 2024
dc58a50
setpoint ada as argument
smilesun Oct 17, 2024
9ddd037
yaml file search setpoint ada
smilesun Oct 17, 2024
83143a9
Update a_trainer.py: change mu order
smilesun Oct 17, 2024
1 change: 1 addition & 0 deletions a_test_feedforward_irm.sh
@@ -0,0 +1 @@
python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999
2 changes: 1 addition & 1 deletion a_test_mhof_irm.sh
@@ -1 +1 @@
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=fbopt_irm --nname=conv_bn_pool_2 --k_i_gain_ratio=0.5
python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()"
63 changes: 58 additions & 5 deletions domainlab/algos/trainers/a_trainer.py
@@ -53,6 +53,7 @@ def __init__(self, successor_node=None, extend=None):
"""
super().__init__(successor_node)
self._model = None
# decoratee can be either a model or a trainer
self._decoratee = extend
self.task = None
self.observer = None
@@ -96,6 +97,8 @@ def __init__(self, successor_node=None, extend=None):
self._ma_iter = 0
#
self.list_reg_over_task_ratio = None
# mhof
self.dict_multiplier = {}


@property
@@ -199,7 +202,13 @@ def before_tr(self):
"""
before training, probe model performance
"""
self.cal_reg_loss_over_task_loss_ratio()
list_mu = self.cal_reg_loss_over_task_loss_ratio()
self.dict_multiplier = {'mu4regloss'+str(i): value for i, value in enumerate(list_mu)}

@property
def list_str_multiplier_na(self):
list_str = list(self.dict_multiplier.keys())
return list_str

def cal_reg_loss_over_task_loss_ratio(self):
"""
@@ -208,18 +217,21 @@ def cal_reg_loss_over_task_loss_ratio(self):
"""
list_accum_reg_loss = []
loss_task_agg = 0
list_mu = None
for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
self.loader_tr
):
if ind_batch >= self.aconf.nb4reg_over_task_ratio:
return
tensor_x, tensor_y, tensor_d = (
tensor_x.to(self.device),
tensor_y.to(self.device),
tensor_d.to(self.device),
)
list_reg_loss_tensor, _ = \
list_reg_loss_tensor, list_mu = \
self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)

if ind_batch >= self.aconf.nb4reg_over_task_ratio:
return list_mu

list_reg_loss_tensor = [torch.sum(tensor).detach().item()
for tensor in list_reg_loss_tensor]
if ind_batch == 0:
@@ -235,6 +247,7 @@ def cal_reg_loss_over_task_loss_ratio(self):
loss_task_agg += tensor_loss_task
self.list_reg_over_task_ratio = [reg_loss / loss_task_agg
for reg_loss in list_accum_reg_loss]
return list_mu

def post_tr(self):
"""
@@ -301,7 +314,27 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
list_loss_tensor = list_reg_loss_model_tensor + \
list_reg_loss_trainer_tensor
list_mu = list_mu_model + list_mu_trainer
return list_loss_tensor, list_mu
# ERM returns a tensor of all zeros; drop those components here
if len(list_mu) > 1:
list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item()
for i in range(len(list_mu))]
list_loss_tensor = [list_loss_tensor[i] for (i, flag) in
enumerate(list_boolean_zero) if not flag]
list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag]
if self.dict_multiplier:
#list_mu = list(self.dict_multiplier.values())
# use a fixed (sorted) key order to keep the order of list_mu consistent across calls
list_mu = [self.dict_multiplier[key] for key in sorted(self.dict_multiplier.keys())]

list_loss_tensor_normalized = list_loss_tensor
if self.list_reg_over_task_ratio:
assert len(list_mu) == len(self.list_reg_over_task_ratio)
list_loss_tensor_normalized = \
[reg_loss / reg_over_task_ratio if reg_over_task_ratio != 0
else reg_loss for (reg_loss, reg_over_task_ratio)
in zip(list_loss_tensor, self.list_reg_over_task_ratio)]

return list_loss_tensor_normalized, list_mu

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
@@ -326,3 +359,23 @@ def print_parameters(self):
"""
params = vars(self)
print(f"Parameters of {type(self).__name__}: {params}")

def hyper_init(self, functor_scheduler, trainer):
"""
initialize both trainer's multiplier and model's multiplier
"""
if not self.dict_multiplier:
raise RuntimeError("self.dict_multiplier empty!")
return functor_scheduler(
trainer=trainer, **self.dict_multiplier
)

def hyper_update(self, epoch, fun_scheduler):
"""hyper_update.

:param epoch: current epoch
:param fun_scheduler: scheduler functor returning the updated multiplier dict
"""
dict_rst = fun_scheduler(epoch)
for key in self.dict_multiplier:
self.dict_multiplier[key] = dict_rst[key]
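The gist of the a_trainer.py change above: before_tr now stores one named multiplier per regularization-loss component in dict_multiplier, and cal_reg_loss reads them back in sorted-key order so the mu vector stays aligned with the reg-loss vector on every call (the "change mu order" of the title), after dropping all-zero components and normalizing by the reg/task ratio. Below is a minimal sketch of that bookkeeping, not the PR code itself; the demo_-prefixed helpers and example values are hypothetical.

```python
# Minimal sketch (not the PR code itself) of the multiplier bookkeeping in
# a_trainer.py; the demo_-prefixed names and example values are hypothetical.
import torch

def demo_build_dict_multiplier(list_mu):
    # before_tr: one key per reg-loss component, e.g. mu4regloss0, mu4regloss1, ...
    return {"mu4regloss" + str(i): value for i, value in enumerate(list_mu)}

def demo_filter_and_normalize(list_loss_tensor, list_mu, list_reg_over_task_ratio):
    # drop components whose reg loss is identically zero (e.g. ERM contributes zeros)
    keep = [not torch.all(torch.eq(t, 0)).item() for t in list_loss_tensor]
    list_loss_tensor = [t for t, k in zip(list_loss_tensor, keep) if k]
    list_mu = [m for m, k in zip(list_mu, keep) if k]
    # normalize each reg loss by the reg/task ratio estimated before training
    list_norm = [t / r if r != 0 else t
                 for t, r in zip(list_loss_tensor, list_reg_over_task_ratio)]
    return list_norm, list_mu

dict_multiplier = demo_build_dict_multiplier([0.001, 0.001])
# "change mu order": always read multipliers in sorted-key order so the i-th mu
# stays aligned with the i-th reg-loss component on every call
list_mu = [dict_multiplier[key] for key in sorted(dict_multiplier.keys())]
```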
72 changes: 51 additions & 21 deletions domainlab/algos/trainers/args_fbopt.py
@@ -9,61 +9,81 @@ def add_args2parser_fbopt(parser):
"""

parser.add_argument(
"--k_i_gain", type=float, default=0.001, help="PID control gain for integrator"
"--k_i_gain", type=float, default=0.001,
help="PID control gain for integrator, if k_i_gain_ratio is not None, \
then this value will be overwritten, see the doc for k_i_gain_ratio"
)

parser.add_argument(
"--k_i_gain_ratio",
type=float,
default=None,
help="set k_i_gain to be ratio of \
initial saturation k_i_gain",
help="set k_i_gain to be ratio of initial saturation k_i_gain \
which K_I * delta = exp_shoulder_clip(saturation value), solve \
for K_I, where delta = reg loss - setpoint. \
Now independent of the scale of delta, the K_I gain will be set so \
that the multiplier will be increased at a rate defined by \
exp_shoulder_clip",
)

parser.add_argument(
"--mu_clip", type=float, default=1e4, help="maximum value of mu"
"--mu_clip", type=float, default=1e4,
help="maximum value of mu: mu_clip should be large enough so that the \
regularization loss as penalty can be weighed superior enough to \
decrease."
)

parser.add_argument(
"--mu_min", type=float, default=1e-6, help="minimum value of mu"
"--mu_min", type=float, default=1e-6, help="minimum value of mu, mu \
can not be negative"
)

parser.add_argument(
"--mu_init", type=float, default=0.001, help="initial beta for multiplication"
"--mu_init", type=float, default=0.001,
help="initial value for each component of the multiplier vector"
)

parser.add_argument(
"--coeff_ma", type=float, default=0.5, help="exponential moving average"
"--coeff_ma", type=float, default=0.5,
help="exponential moving average of delta \
(reg minus setpoint as control error): \
move_ave=move_ave + coeff*delta(current value)"
)

parser.add_argument(
"--coeff_ma_output_state",
type=float,
default=0.1,
help="state exponential moving average of \
reguarlization loss",
help="output (reguarization loss) exponential moving average \
move_ave=move_ave*coeef + reg(current value)",
)

parser.add_argument(
"--coeff_ma_setpoint",
type=float,
default=0.9,
help="setpoint average coeff for previous setpoint",
help="setpoint average (coeff for previous setpoint)",
)

parser.add_argument(
"--exp_shoulder_clip",
type=float,
default=5,
help="clip before exponential operation",
help="clip delta(control error): \
R(reg-loss)-b(setpoint) before exponential operation: \
exp[clip(R-b, exp_shoulder_clip)].\
exponential magnifies control error, so this argument \
defines the maximum rate of change of multipliers \
exp(5)=148, exp_shoulder_clip should not be too big, \
if exp_shoulder_clip is small, then more like exterior point method",
)

parser.add_argument(
"--ini_setpoint_ratio",
type=float,
default=0.99,
help="before training start, evaluate reg loss, \
setpoint will be 0.9 of this loss",
setpoint will be 0.9 of this loss",
)

parser.add_argument(
@@ -77,8 +97,8 @@ def add_args2parser_fbopt(parser):
"--force_setpoint_change_once",
action="store_true",
default=False,
help="train until the setpoint changed at least once \
up to maximum epos specified",
help="continue training until the setpoint changed at least once: \
up to maximum epos specified",
)

parser.add_argument(
@@ -103,24 +123,34 @@ def add_args2parser_fbopt(parser):
)

parser.add_argument(
"--overshoot_rewind",
type=str,
default="yes",
help="overshoot_rewind, for benchmark, use yes or no",
"--no_overshoot_rewind",
action="store_true",
default=False,
help="disable overshoot rewind: when reg loss satisfies setpoint \
already, then set activation=K_I*delta = 0",
)

parser.add_argument(
"--setpoint_rewind",
action="store_true",
default=False,
help="rewind setpoint",
)

# this arg is only used when model is set to be "diva"
parser.add_argument(
"--str_setpoint_ada",
type=str,
default="no",
help="setpoing_rewind, for benchmark, use yes or no",
default="DominateAllComponent()",
help="which setpoint adaptation criteria to use",
)

# this arg is only used when model is set to be "diva"
parser.add_argument(
"--str_diva_multiplier_type",
type=str,
default="gammad_recon",
help="which penalty to tune",
help="which penalty to tune, only useful to DIVA model",
)

return parser
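A small numeric illustration of the --k_i_gain_ratio / --exp_shoulder_clip relation documented above (all values invented): the saturating gain is solved per component from K_I * delta = exp_shoulder_clip, and the final shared gain is the ratio times the smallest of those solutions.

```python
# Hedged numeric example (values invented) of the k_i_gain_ratio rule described
# in the help text above.
exp_shoulder_clip = 1.0
k_i_gain_ratio = 0.5
delta = [0.4, 2.0]                                       # reg loss minus setpoint, per component
k_i_saturate = [exp_shoulder_clip / d for d in delta]    # [2.5, 0.5]
k_i_control = k_i_gain_ratio * min(k_i_saturate)         # 0.25, shared by every component
print(k_i_control)
```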
32 changes: 23 additions & 9 deletions domainlab/algos/trainers/fbopt_mu_controller.py
@@ -50,15 +50,16 @@ def __init__(self, trainer, **kwargs):
self.mmu = {key: self.init_mu for key, val in self.mmu.items()}
self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf)

self.k_i_control = trainer.aconf.k_i_gain
self.k_i_gain_ratio = None
self.overshoot_rewind = trainer.aconf.overshoot_rewind == "yes"
self.k_i_control = [trainer.aconf.k_i_gain for i in
range(len(self.mmu))]
self.k_i_gain_ratio = trainer.aconf.k_i_gain_ratio
self.overshoot_rewind = not trainer.aconf.no_overshoot_rewind
self.delta_epsilon_r = None

# NOTE: this value will be set according to initial evaluation of
# neural network
self.activation_clip = trainer.aconf.exp_shoulder_clip
self.coeff_ma = trainer.aconf.coeff_ma
self.coeff4newval_ma_delta = trainer.aconf.coeff_ma
# NOTE:
# print(copy.deepcopy(self.model))
# TypeError: cannot pickle '_thread.lock' object
@@ -70,21 +71,33 @@ def __init__(self, trainer, **kwargs):

def set_k_i_gain(self, epo_reg_loss):
if self.k_i_gain_ratio is None:
return
if self.k_i_control:
return
raise RuntimeError("set either direct k_i_control value or \
set k_i_gain_ratio, can not be both empty!")
# NOTE: do not use self.cal_delta4control!!!! which will change
# class member variables self.delta_epsilon_r!
list_setpoint = self.get_setpoint4r()
if_list_sign_agree(epo_reg_loss, list_setpoint)
delta_epsilon_r = [a - b for a, b in zip(epo_reg_loss, list_setpoint)]

# to calculate self.delta_epsilon_r
list_active = [self.activation_clip for i in range(len(delta_epsilon_r))]

k_i_gain_saturate = [
a / b for a, b in zip(self.activation_clip, delta_epsilon_r)
a / b for a, b in zip(list_active, delta_epsilon_r)
]

# FIXME: add max K_I gain here if initial delta is too small

k_i_gain_saturate_min = min(k_i_gain_saturate)
# NOTE: here we override the commandline arguments specification
# for k_i_control, so k_i_control is not a hyperparameter anymore
self.k_i_control = self.k_i_gain_ratio * k_i_gain_saturate_min
# self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate]
# K_I should be the same for each component; the control error already
# makes the multiplier magnification different
self.k_i_control = [self.k_i_gain_ratio * k_i_gain_saturate_min for i
in range(len(delta_epsilon_r))]
warnings.warn(
f"hyperparameter k_i_gain disabled! \
replaced with {self.k_i_control}"
Expand Down Expand Up @@ -118,7 +131,8 @@ def cal_delta4control(self, list1, list_setpoint):
# self.delta_epsilon_r is the previous time step.
# delta_epsilon_r is the current time step
self.delta_epsilon_r = self.cal_delta_integration(
self.delta_epsilon_r, delta_epsilon_r, self.coeff_ma
self.delta_epsilon_r, delta_epsilon_r,
self.coeff4newval_ma_delta
)

def cal_delta_integration(self, list_old, list_new, coeff):
Expand Down Expand Up @@ -162,7 +176,7 @@ def cal_activation(self):
"""
setpoint = self.get_setpoint4r()
activation = [
self.k_i_control * val if setpoint[i] > 0 else self.k_i_control * (-val)
self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val)
for i, val in enumerate(self.delta_epsilon_r)
]
if self.activation_clip is not None:
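Sketch of the vectorised controller behaviour introduced in fbopt_mu_controller.py: k_i_control is now a per-component list, and the activation is K_I[i] * delta[i] with the sign taken from the setpoint and clipped by exp_shoulder_clip. The helper below is a hypothetical stand-in, not the controller class itself, and the symmetric clipping step is an assumption about how activation_clip is applied.

```python
# Hypothetical stand-in for the vectorised activation in fbopt_mu_controller.py:
# one K_I gain per component, sign from the setpoint, clipped activation.
def demo_cal_activation(k_i_control, delta_epsilon_r, setpoint, clip):
    activation = [k_i_control[i] * d if setpoint[i] > 0 else k_i_control[i] * (-d)
                  for i, d in enumerate(delta_epsilon_r)]
    if clip is not None:
        # assumption: symmetric clipping at +/- exp_shoulder_clip
        activation = [max(min(a, clip), -clip) for a in activation]
    return activation

# two components, shared gain 0.25, clip at 1.0
print(demo_cal_activation([0.25, 0.25], [0.4, 8.0], [1.0, 1.0], 1.0))  # [0.1, 1.0]
```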
18 changes: 14 additions & 4 deletions domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -81,7 +81,7 @@ def __init__(self, host):
self.counter = None
self.epo_ma = None
self.ref = None
self.coeff_ma = 0.5
self.coeff_ma_setpoint_rewinder = 0.5
self.setpoint_rewind = host.flag_setpoint_rewind

def reset(self, epo_reg_loss):
@@ -98,7 +98,7 @@ def observe(self, epo_reg_loss):
"""
if self.ref is None:
self.reset(epo_reg_loss)
self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma)
self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma_setpoint_rewinder)
list_comparison_increase = [a < b for a, b in zip(self.ref, self.epo_ma)]
list_comparison_above_setpoint = [
a < b for a, b in zip(self.host.setpoint4R, self.epo_ma)
@@ -144,9 +144,11 @@ def __init__(self, state=None, args=None):
if args is not None and args.no_setpoint_update:
state = FixedSetpoint()
else:
state = DominateAllComponent()
# state = eval('DominateAllComponent()')
# state = DominateAllComponent()
state = eval(args.str_setpoint_ada)
self.transition_to(state)
self.flag_setpoint_rewind = args.setpoint_rewind == "yes"
self.flag_setpoint_rewind = args.setpoint_rewind
self.setpoint_rewinder = SetpointRewinder(self)
self.state_task_loss = 0.0
self.state_epo_reg_loss = [
@@ -311,4 +313,12 @@ def update_setpoint(self):
{self.host.state_task_loss}"
)
self.host.setpoint4ell = self.host.state_task_loss

if flag1 & flag2:
logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO")
logger.info(
"!!!!!!!!!In DominantAllComponent: \
besides each components of reg loss shrinks \
task loss also decreased which forms dominance!"
)
return flag1 & flag2, list_pos
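For orientation, a rough sketch of what the setpoint-adaptation criterion selected by --str_setpoint_ada amounts to, assuming a DominateAllComponent-style rule that only moves the setpoint when every reg-loss component and the task loss have improved; the helper name and the exact comparison baselines are assumptions, not the actual class, though setpoint4R and setpoint4ell are taken from the diff above.

```python
# Rough, hypothetical sketch of a DominateAllComponent-style update criterion:
# the setpoint only moves when every reg-loss component has shrunk AND the task
# loss also decreased ("dominance"); the comparison baselines are assumptions.
def demo_dominate_all(epo_reg_loss, setpoint4R, task_loss, setpoint4ell):
    flag_reg = all(r < s for r, s in zip(epo_reg_loss, setpoint4R))  # each component improved
    flag_task = task_loss < setpoint4ell                             # task loss improved too
    return flag_reg and flag_task

print(demo_dominate_all([0.2, 0.3], [0.25, 0.4], task_loss=1.1, setpoint4ell=1.3))  # True
```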