Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fbopt #462

Open
wants to merge 1,051 commits into
base: master
Choose a base branch
from
Open

Fbopt #462

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
1051 commits
Select commit Hold shift + click to select a range
cf0d6c4
autoki diva
smilesun Oct 27, 2023
33dd359
yaml
smilesun Oct 27, 2023
5ba7392
beter plot
smilesun Oct 27, 2023
cb86c44
better plot
smilesun Oct 27, 2023
3cea9bd
.
smilesun Oct 27, 2023
e6cac85
more ki gain
smilesun Oct 27, 2023
f029f14
yaml
smilesun Oct 27, 2023
41168f9
2.2
smilesun Oct 27, 2023
4d8f8e2
.
smilesun Oct 27, 2023
cc19e66
Merge branch 'fbopt' into fbopt_dial
smilesun Oct 28, 2023
16da3dd
diva yaml mnisT
smilesun Oct 30, 2023
87389df
improve epos_min jigenauto ki
smilesun Oct 30, 2023
6653ab2
Merge branch 'master' into fbopt
smilesun Oct 31, 2023
125d142
force setpoint to be met as an option fix issue # 604
smilesun Oct 31, 2023
d3524c3
force setpoing in script
smilesun Oct 31, 2023
e7d45dc
Merge branch 'master' into fbopt
smilesun Oct 31, 2023
9c32106
more epochs
smilesun Oct 31, 2023
ab05c06
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Oct 31, 2023
a2553b3
jigen
smilesun Oct 31, 2023
e161df5
Update pacs_jigen_fbopt_alone_autoki.yaml
smilesun Oct 31, 2023
350d352
https://github.com/marrlab/DomainLab/issues/601
agisga Oct 31, 2023
30a5b50
pacs with aut
smilesun Oct 31, 2023
4f29ec4
fix issue #610
smilesun Oct 31, 2023
b7d8528
jigen yaml
smilesun Oct 31, 2023
c2d5015
set arrow head size dynamically
agisga Oct 31, 2023
748c510
Merge pull request #612 from marrlab/phase_portrait_tweaks
smilesun Oct 31, 2023
7065292
Merge pull request #613 from marrlab/phase_portrait_tweaks
smilesun Oct 31, 2023
b085b29
yaml
smilesun Nov 1, 2023
822b42d
force_setpoint_change_once
smilesun Nov 1, 2023
55380ab
Merge branch 'master' into fbopt
smilesun Nov 1, 2023
9e9c37d
increae iteratino jigen
smilesun Nov 1, 2023
2cfebe1
basline jigen
smilesun Nov 2, 2023
107f595
enable san check
smilesun Nov 2, 2023
20582a2
Merge branch 'master' into fbopt
smilesun Nov 2, 2023
946986e
force feedforwrad in feedback
smilesun Nov 2, 2023
56e59e4
Merge pull request #624 from marrlab/fbopt_force_feedforward
smilesun Nov 2, 2023
0ca02c6
use aug in diva
smilesun Nov 2, 2023
d605670
.
smilesun Nov 2, 2023
0e4030f
Merge branch 'master' into fbopt
smilesun Nov 3, 2023
6fd6107
.
smilesun Nov 3, 2023
0f1e851
.
smilesun Nov 3, 2023
439e6de
Update pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml
smilesun Nov 3, 2023
77c9f66
Update pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml
smilesun Nov 3, 2023
20d62d3
.
smilesun Nov 3, 2023
c272ec5
new command for list errors
smilesun Nov 6, 2023
4f87892
Merge remote-tracking branch 'refs/remotes/origin/fbopt' into fbopt
smilesun Nov 6, 2023
c44a529
200 epochs for diva
smilesun Nov 6, 2023
3b580e2
sbmit comamnd update
smilesun Nov 6, 2023
f8a15cd
Merge branch 'master' into fbopt
smilesun Nov 6, 2023
849d3c2
no flip for jigen
smilesun Nov 6, 2023
77f4c5e
Merge branch 'master' into fbopt
smilesun Nov 6, 2023
738c667
hduva
smilesun Nov 6, 2023
33f9377
jigen no flip yaml
smilesun Nov 6, 2023
01b229e
add missing attribute for hduva list of names for reg loss
smilesun Nov 6, 2023
f87d141
less bs for hduva
smilesun Nov 6, 2023
d160066
add erm to no flip augmentation
smilesun Nov 6, 2023
2085657
find hyperindx from slurm id
smilesun Nov 6, 2023
3a0855a
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Nov 6, 2023
c2c4fc2
fix issue #528 print model multiplier from model instead of scheduler
smilesun Nov 7, 2023
82afc2a
smalelr batch size
smilesun Nov 9, 2023
b5dcff5
fix issue # 636
smilesun Nov 10, 2023
d226b28
.
smilesun Nov 10, 2023
4e4dac1
hduva basleine
smilesun Nov 15, 2023
85c69a0
.
smilesun Nov 15, 2023
f6b0f80
doc
smilesun Nov 16, 2023
01810c7
not sure if we really need wrapper
smilesun Nov 16, 2023
088cc9c
removed wrapper, memory too small
smilesun Nov 16, 2023
3230af0
.
smilesun Nov 16, 2023
a5d252d
lower mu clip hduva
smilesun Nov 17, 2023
c752b0a
Merge branch 'fbopt' into fbopt_matchduva
smilesun Nov 17, 2023
8a4e464
defautl extrac feat to cal logit
smilesun Nov 17, 2023
37aa8d5
doc
smilesun Nov 17, 2023
845dcd3
no need for warp
smilesun Nov 17, 2023
fde0117
Merge branch 'fbopt' into fbopt_matchduva
smilesun Nov 17, 2023
7f82bc9
remove redundancy
smilesun Nov 17, 2023
23d2c51
matchduva
smilesun Nov 17, 2023
03cdb2e
Merge branch 'master' into fbopt
smilesun Nov 20, 2023
634861d
Merge branch 'fbopt' into fbopt_matchduva
smilesun Nov 20, 2023
6e5bc86
Merge branch 'master' into fbopt
smilesun Nov 20, 2023
ee79c91
merge fbopt
smilesun Nov 20, 2023
af67865
Merge pull request #641 from marrlab/fbopt_matchduva
smilesun Nov 22, 2023
c73885e
.
smilesun Nov 22, 2023
2088d5f
Merge branch 'fbopt' into fbopt_dial
smilesun Nov 23, 2023
eb3890a
import
smilesun Nov 23, 2023
0490317
.
smilesun Nov 23, 2023
04af111
correct observer api
smilesun Nov 23, 2023
f448677
Merge branch 'master' into fbopt
smilesun Nov 23, 2023
6a1dc10
Merge branch 'fbopt' into fbopt_dial
smilesun Nov 23, 2023
322bf5a
as_model in mk_opt
smilesun Nov 23, 2023
8f4b9c1
Merge branch 'fbopt_dial' of github.com:marrlab/DomainLab into fbopt_…
smilesun Nov 23, 2023
6f3fd84
.
smilesun Nov 23, 2023
59e0ee0
Merge branch 'master' into fbopt
smilesun Nov 23, 2023
539593c
Merge branch 'fbopt' into fbopt_dial
smilesun Nov 23, 2023
4ac2c2f
fbot has to use as_model explicitly
smilesun Nov 27, 2023
9f73eab
dial diva
smilesun Nov 27, 2023
b2ef567
delete unecenarY
smilesun Nov 27, 2023
2e3cb62
seems working
smilesun Nov 27, 2023
83a3b12
use cpu
smilesun Nov 27, 2023
d56f5a0
redefine self.model
smilesun Nov 28, 2023
08eb8b6
delete all unnecessary wraper function
smilesun Nov 28, 2023
3dda459
remove further uncessiary wrapper
smilesun Nov 28, 2023
3a91cf4
add model setter
smilesun Nov 28, 2023
9ad5b8b
dial only task loss as regularization
smilesun Nov 28, 2023
69f761d
Merge pull request #391 from marrlab/fbopt_dial
smilesun Nov 28, 2023
51ec54a
model to device
smilesun Nov 29, 2023
21de8a2
cal task loss correct api
smilesun Nov 29, 2023
81059da
better format
smilesun Nov 29, 2023
154ac0b
Merge branch 'master' into fbopt
smilesun Nov 29, 2023
c8bd527
Merge branch 'master' into fbopt
smilesun Nov 29, 2023
3ab0b2e
Merge branch 'master' into fbopt
smilesun Nov 30, 2023
7eb6b6b
Merge branch 'master' into fbopt
smilesun Nov 30, 2023
bc899eb
delete redundancy
smilesun Nov 30, 2023
b7d3810
Merge branch 'master' into fbopt
smilesun Nov 30, 2023
b5c42bb
Merge branch 'master' into fbopt
smilesun Nov 30, 2023
ba52e22
Merge branch 'master' into fbopt
smilesun Dec 1, 2023
cea6303
Merge branch 'master' into fbopt
smilesun Dec 1, 2023
c207e71
Merge branch 'master' into fbopt
smilesun Dec 1, 2023
0987721
Merge branch 'master' into fbopt
smilesun Dec 1, 2023
237ba5b
flag_info=False
smilesun Dec 1, 2023
dd09077
after repoch with flag_info
smilesun Dec 1, 2023
68e8d4d
Merge branch 'master' into fbopt
smilesun Dec 2, 2023
ad8ea49
Merge branch 'master' into fbopt
smilesun Dec 2, 2023
9800709
Merge branch 'master' into fbopt
smilesun Dec 3, 2023
7be1ba3
Merge branch 'master' into fbopt
smilesun Dec 4, 2023
4d8e434
use others in fbopt
smilesun Dec 4, 2023
3b7b78f
Merge branch 'master' into fbopt
smilesun Dec 4, 2023
3d63caa
remove as_model
smilesun Dec 4, 2023
9f8993a
Merge branch 'master' into fbopt
smilesun Dec 4, 2023
74f6d11
identifies the thread_lock problem from writer, next step is to refac…
smilesun Dec 4, 2023
d2fe043
exp pre deepcopy to avoid thread lock
smilesun Dec 4, 2023
694859f
Merge branch 'master' into fbopt
smilesun Dec 5, 2023
9c9631a
Merge branch 'master' into fbopt
smilesun Dec 7, 2023
bc8f925
Merge branch 'master' into fbopt
smilesun Dec 7, 2023
37456f1
change new name for warmup linear
smilesun Dec 7, 2023
24b24b9
Merge branch 'master' into fbopt
smilesun Dec 11, 2023
badd404
Merge branch 'master' into fbopt
smilesun Dec 15, 2023
c36a7ef
Merge branch 'master' into fbopt
smilesun Dec 15, 2023
01ed0f3
Merge branch 'master' into fbopt
smilesun Dec 15, 2023
232da40
aname to model, matchdg removel, jigen error
smilesun Dec 17, 2023
b7ef211
Merge branch 'master' into fbopt
smilesun Dec 17, 2023
67a7ff8
Merge branch 'master' into fbopt
smilesun Dec 17, 2023
a54a644
Merge branch 'master' into fbopt
smilesun Dec 18, 2023
b733aab
erm
smilesun Dec 18, 2023
9a44ad4
Merge branch 'master' into fbopt
smilesun Dec 18, 2023
bb14023
Merge branch 'master' into fbopt
smilesun Dec 22, 2023
233c311
merge hduva nn name change
smilesun Dec 22, 2023
da31dde
Merge branch 'master' into fbopt
smilesun Dec 22, 2023
284b729
Merge branch 'master' into fbopt
smilesun Dec 30, 2023
a8d5681
Merge branch 'master' into fbopt
smilesun Jan 6, 2024
d64fbd5
Merge branch 'master' into fbopt
smilesun Jan 6, 2024
28983c6
Merge branch 'master' into fbopt
smilesun Jan 6, 2024
ccbcd45
Merge branch 'master' into fbopt
smilesun Jan 7, 2024
e768291
Merge branch 'master' into fbopt
smilesun Jan 7, 2024
ec41611
Merge branch 'master' into fbopt
smilesun Jan 8, 2024
628c5ae
Merge branch 'master' into fbopt
smilesun Jan 9, 2024
f470432
Merge branch 'master' into fbopt
smilesun Jan 10, 2024
472e25e
Merge branch 'master' into fbopt
smilesun Jan 13, 2024
eab26f8
first not working commit to combine diva with matchdg
smilesun Jan 16, 2024
3140758
fix merge conflict
smilesun Jan 17, 2024
a3e36fe
fix syntax
smilesun Jan 17, 2024
100c687
fix diva
smilesun Jan 17, 2024
2d70f93
Merge branch 'fbopt' into fbopt_matchdiva
smilesun Jan 17, 2024
a3f3559
code style
ntchen Jan 17, 2024
2ec560e
c_msel_val_top_k
ntchen Jan 17, 2024
96e0b44
c_msel_setpoin_delay
ntchen Jan 17, 2024
791d4b7
resolve too many attributes
ntchen Jan 17, 2024
30393a3
ignore too many local variables
ntchen Jan 17, 2024
3a077f6
code style
ntchen Jan 17, 2024
c39fb00
pylint
ntchen Jan 17, 2024
fb93437
pylint
ntchen Jan 17, 2024
b226ffb
Merge pull request #750 from marrlab/foptcodacy
smilesun Jan 18, 2024
898af0b
Merge branch 'fbopt' into fbopt_matchdiva
smilesun Jan 18, 2024
358c936
Merge branch 'master' into fbopt
smilesun Jan 18, 2024
7c92b5e
Merge pull request #746 from marrlab/fbopt_matchdiva
smilesun Jan 18, 2024
aa1d5cd
Merge branch 'master' into fbopt
smilesun Jan 19, 2024
2319105
Update pyproject.toml
smilesun Jan 19, 2024
04a6f25
Update ci.yml: revert how pytest runs: no poetry
smilesun Jan 19, 2024
1298519
Update pyproject.toml: rm tensorboard = "^2.14.0"
smilesun Jan 19, 2024
0248bb0
Merge branch 'master' into fbopt
smilesun Jan 20, 2024
5bd00d6
Update ci.yml
smilesun Jan 21, 2024
108c2e7
Update ci.yml
smilesun Jan 21, 2024
d615736
Merge branch 'master' into fbopt
smilesun Jan 21, 2024
00a30af
add tensorboard to pyproject
smilesun Jan 21, 2024
6dbe7d6
update poetry.lock file
smilesun Jan 22, 2024
486d81a
update dependencies
agisga Jan 23, 2024
6a39fb6
Merge pull request #765 from marrlab/issue_754
smilesun Jan 24, 2024
b7d21db
Merge branch 'master' into fbopt
smilesun Jan 24, 2024
2cd16d8
Merge branch 'master' into fbopt
smilesun Jan 29, 2024
1c4dfd1
reduce nubmer of iterations
smilesun Jan 29, 2024
eb812c2
.
smilesun Jan 29, 2024
1374b33
use aug pac
smilesun Feb 1, 2024
b763838
Update pacs_diva_fbopt_alone_es1.yaml: batch 64 to 32
smilesun Feb 1, 2024
6d34ab4
Update pacs_diva_fbopt_alone_es1_autoki.yaml
smilesun Feb 1, 2024
a3fa382
es=10
smilesun Feb 6, 2024
964c892
.
smilesun Feb 6, 2024
01e03bb
merge conflict with master
smilesun Feb 6, 2024
f224039
Merge branch 'master' into fbopt
smilesun Feb 7, 2024
4c69a52
force setpoint change once
smilesun Feb 7, 2024
775e41f
Update pacs_diva_fbopt_alone_es1_autoki.yaml
smilesun Feb 8, 2024
600e39e
attempt to reproduce previous diva results on pacs data
agisga Feb 8, 2024
b58cbff
Merge pull request #778 from marrlab/diva_pacs_reproduce
smilesun Feb 8, 2024
1c3f26a
1e-5 type notation not supported for yaml file
agisga Feb 8, 2024
8b9d921
.
smilesun Feb 8, 2024
180f2b8
better yaml for benchmark
smilesun Feb 9, 2024
6541cd4
refine yaml file
smilesun Feb 9, 2024
e05081a
Rename pacs_diva_fbopt_alone_es1.yaml to pacs_diva_fbopt_alone_es1_ra…
smilesun Feb 13, 2024
5bf06e0
Merge branch 'master' into fbopt
smilesun Feb 14, 2024
b0e679f
add gamma-y sample
smilesun Feb 15, 2024
99c9ad2
minor improvements to generation of figures for the fbopt experiments
agisga Feb 15, 2024
e38e9b5
add class to put tensorboard data to text file
smilesun Feb 23, 2024
e6ffee6
make fbopt script into separate folder
smilesun Feb 23, 2024
9ecb520
improve code
smilesun Feb 23, 2024
eee82b4
Merge branch 'fbopt' into fbopt_visualizations
agisga Feb 23, 2024
7b13762
Merge pull request #782 from marrlab/fbopt_visualizations
agisga Feb 23, 2024
8cca9cf
minor adjustments to the visualizations
agisga Feb 23, 2024
6dd6b2d
fbopt figures: backup x and y data to txt files
agisga Feb 23, 2024
e538970
bug fix for the last commit
agisga Feb 23, 2024
76b73a3
fbopt plots: fixes txt saving
agisga Feb 23, 2024
f5c06a7
single run reproduce
smilesun Feb 26, 2024
ba1e19d
move benchmark submit back
smilesun Feb 26, 2024
e3ef1ab
change mode for submission script
smilesun Feb 26, 2024
98689ff
add output dir
smilesun Feb 27, 2024
eea73e3
added matplotlib from nutan
smilesun Feb 27, 2024
33c027a
logscale
smilesun Feb 27, 2024
d11da9b
complex math display not possible
smilesun Feb 27, 2024
3f1f04f
change output name
smilesun Feb 27, 2024
9907274
remove latex in filename
smilesun Feb 27, 2024
0ac873d
.
smilesun Feb 27, 2024
3edc169
todo
smilesun Feb 27, 2024
75eaaa3
.
smilesun Feb 28, 2024
c8f5f5f
skip draw
smilesun Feb 28, 2024
4e4f490
arrow width
smilesun Feb 28, 2024
cd800eb
plot len
smilesun Feb 28, 2024
1712cac
phase portrait arrow size automatic
smilesun Feb 28, 2024
d3b0a12
color bar for phase portrait
smilesun Feb 28, 2024
5603790
bounding gbox tight
smilesun Feb 28, 2024
9e1eb99
remove \ in filename
smilesun Feb 28, 2024
5364878
more latex
smilesun Feb 28, 2024
7385e81
.
smilesun Feb 28, 2024
0ed3a21
latex in plot
smilesun Feb 28, 2024
3b1bb09
phase portrait neg
smilesun Feb 28, 2024
b70dbd0
log scale to phase portrait plot after carla's suggestion
smilesun Feb 29, 2024
382005e
log single curve
smilesun Feb 29, 2024
f3e9d0f
comment how setpoint model selction works
smilesun Mar 1, 2024
1ae0ef7
.
smilesun Mar 8, 2024
56e06cb
.
smilesun Mar 8, 2024
5e7b9c9
Merge branch 'master' into fbopt
smilesun Mar 8, 2024
6f37464
comment
smilesun Mar 14, 2024
17bfc4a
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Mar 14, 2024
7bdf2ef
dial
smilesun May 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: CI

on:
push:
branches: master
branches: fbopt
pull_request:
branches: master
branches: fbopt
workflow_dispatch:

jobs:
Expand Down
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,39 @@ For development version in Github, see [Installation and Dependencies handling](

We also offer a PyPI version here https://pypi.org/project/domainlab/ which one could install via `pip install domainlab` and it is recommended to create a virtual environment for it.


#### Guide for Helmholtz GPU cluster
```
conda create --name domainlab_py39 python=3.9
conda activate domainlab_py39
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
conda install torchmetrics==0.10.3
git checkout fbopt
pip install -r requirements_notorch.txt
conda install tensorboard
```

#### Download PACS

step 1:

Use the following script to download PACS to your local machine, then upload the downloaded data to your cluster:

https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py

step 2:
make a symbolic link following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh

Here `mkdir -p data/pacs` is executed from the repository directory, and

`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS`
creates the symbolic link under the repository directory.

### Task specification
We offer various ways for the user to specify a scenario to evaluate the generalization performance via training on a limited number of datasets. See details in
[Task Specification](./docs/doc_tasks.md)


### Example and usage

#### Command line
Expand Down
24 changes: 24 additions & 0 deletions a_reproduce_pacs_diva.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
te_d: sketch
tpath: examples/tasks/task_pacs_aug.py
bs: 32
model: diva
trainer: fbopt
gamma_y: 1.0
ini_setpoint_ratio: 0.99
str_diva_multiplier_type: gammad_recon
coeff_ma_output_state: 0.1
coeff_ma_setpoint: 0.9
exp_shoulder_clip: 5
mu_init: 0.000001
k_i_gain_ratio: 0.5
mu_clip: 10
epos: 1000
epos_min: 200
npath: examples/nets/resnet50domainbed.py
npath_dom: examples/nets/resnet50domainbed.py
es: 2
lr: 0.00005
zx_dim: 0
zy_dim: 64
zd_dim: 64
force_setpoint_change_once: True
9 changes: 7 additions & 2 deletions domainlab/algos/builder_diva.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen
Expand Down Expand Up @@ -35,7 +37,8 @@ def init_business(self, exp):
request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args)
node = VAEChainNodeGetter(request)()
task.get_list_domains_tr_te(args.tr_d, args.te_d)
model = mk_diva(list_str_y=task.list_str_y)(
model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type, list_str_y=task.list_str_y)(

node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
Expand All @@ -48,7 +51,9 @@ def init_business(self, exp):
beta_d=args.beta_d,
)
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelSetpointDelay(
MSelOracleVisitor(MSelValPerfTopK(max_es=args.es))
)
if not args.gen:
observer = ObVisitor(model_sel)
else:
Expand Down
21 changes: 21 additions & 0 deletions domainlab/algos/builder_fbopt_dial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
builder for feedback optimization of dial
"""
from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA
from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt


class NodeAlgoBuilderFbOptDial(NodeAlgoBuilderDIVA):
    """
    Builder that wraps the trainer produced by the DIVA builder inside a
    feedback-optimization trainer (TrainerFbOpt).
    """

    def init_business(self, exp):
        """
        Build trainer, model, observer and device via the parent builder,
        initialize the parent's trainer, then hand it to a TrainerFbOpt.

        :param exp: experiment object; its ``task`` and ``args`` attributes
            are forwarded to both trainers
        :return: tuple (trainer, model, observer, device), where trainer is
            the TrainerFbOpt instance wrapping the parent's trainer
        """
        inner_trainer, model, observer, device = super().init_business(exp)
        # the inner trainer is initialized with the model itself ...
        inner_trainer.init_business(model, exp.task, observer, device, exp.args)
        # ... while the feedback trainer receives the inner trainer in the
        # position where the model was passed above
        fbopt_trainer = TrainerFbOpt()
        fbopt_trainer.init_business(
            inner_trainer, exp.task, observer, device, exp.args
        )
        return fbopt_trainer, model, observer, device
4 changes: 3 additions & 1 deletion domainlab/algos/builder_jigen1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
Expand All @@ -29,7 +31,7 @@ def init_business(self, exp):
task = exp.task
args = exp.args
device = get_device(args)
msel = MSelOracleVisitor(msel=MSelValPerf(max_es=args.es))
msel = MSelSetpointDelay(MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)))
observer = ObVisitor(msel)
observer = ObVisitorCleanUp(observer)

Expand Down
6 changes: 6 additions & 0 deletions domainlab/algos/msels/a_model_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,9 @@ def sel_model_te_acc(self):
if self.msel is not None:
return self.msel.sel_model_te_acc
return -1

@property
def oracle_last_setpoint_sel_te_acc(self):
    """
    Forward the decoratee's ``oracle_last_setpoint_sel_te_acc``;
    return -1 when there is no decoratee (``self.msel`` is None).
    """
    if self.msel is None:
        return -1
    return self.msel.oracle_last_setpoint_sel_te_acc
54 changes: 54 additions & 0 deletions domainlab/algos/msels/c_msel_setpoint_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
logs the best up-to-event selected model at each event when setpoint shrinks
"""
from domainlab.algos.msels.a_model_sel import AMSel
from domainlab.utils.logger import Logger


class MSelSetpointDelay(AMSel):
    """
    Decorator around another model selection object.

    Whenever ``update`` is called with a truthy ``clear_counter`` (the caller
    signals that the setpoint has shrunk), the test accuracy currently
    selected by the decoratee is logged and remembered. The decision whether
    a model is selected is always left to the decoratee.
    """

    def __init__(self, msel):
        """
        :param msel: the model selection object to decorate
        """
        # NOTE: super().__init__() must run before assigning self.msel,
        # otherwise self.msel would be overwritten to be None
        super().__init__()
        self.msel = msel
        # starts at 0.0 and is only changed by a setpoint event in update()
        self._oracle_last_setpoint_sel_te_acc = 0.0

    @property
    def oracle_last_setpoint_sel_te_acc(self):
        """
        test accuracy remembered at the most recent setpoint event
        (0.0 if no such event has happened yet)
        """
        return self._oracle_last_setpoint_sel_te_acc

    def _remember_setpoint_te_acc(self, log):
        """
        log the transition and store the decoratee's currently selected
        test accuracy as the new setpoint accuracy
        """
        # self.sel_model_te_acc is a property defined in a_msel which returns
        # self.msel.sel_model_te_acc, i.e. the validation acc based model
        # selection, which does not take the setpoint into account
        log.info(
            f"setpoint msel te acc updated from "
            f"{self._oracle_last_setpoint_sel_te_acc} to {self.sel_model_te_acc}"
        )
        self._oracle_last_setpoint_sel_te_acc = self.sel_model_te_acc

    def update(self, clear_counter=False):
        """
        Record the setpoint accuracy if requested, then delegate.

        :param clear_counter: in the current version of the code this carries
            flag_setpoint_updated, set via
            flag = super().tr_epoch(epoch, self.flag_setpoint_updated)
        :return: the flag returned by the decoratee's update
        """
        log = Logger.get_logger()
        log.info(
            f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}"
        )
        if clear_counter:
            self._remember_setpoint_te_acc(log)
        # let decoratee decide if model should be selected or not
        return self.msel.update(clear_counter)
61 changes: 61 additions & 0 deletions domainlab/algos/msels/c_msel_val_top_k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Model Selection should be decoupled from
"""
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.utils.logger import Logger


class MSelValPerfTopK(MSelValPerf):
"""
Model selection that tracks the top-k validation metric values.

1. Model selection using validation performance
2. Visitor pattern to trainer

update() returns True whenever the current validation metric is strictly
larger than the smallest value stored in list_top_k_acc; otherwise it
returns the flag computed by the parent class MSelValPerf.
"""

def __init__(self, max_es, top_k=2):
"""
:param max_es: forwarded unchanged to MSelValPerf.__init__
:param top_k: number of validation values to keep, each slot starts at 0.0
"""
super().__init__(max_es) # construct self.tr_obs (observer)
self.top_k = top_k
# NOTE(review): top_k < 1 yields an empty list, on which min() in update() raises
self.list_top_k_acc = [0.0 for _ in range(top_k)]

def update(self, clear_counter=False):
"""
Decide whether the best model should be updated.

:param clear_counter: forwarded unchanged to the parent's update
:return: True if the current validation metric exceeds the smallest stored
top-k value, otherwise the parent's flag
"""
# the parent update runs unconditionally; its flag is only used as fallback
flag_super = super().update(clear_counter)
metric_val_current = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
acc_min = min(self.list_top_k_acc)
if metric_val_current > acc_min:
# current validation metric beats the weakest stored value: overwrite
# NOTE(review): the backslash continuation inside the f-string below keeps
# the leading whitespace of the continuation line in the logged message
logger = Logger.get_logger()
logger.info(
f"top k validation acc: {self.list_top_k_acc} \
overwriting/reset counter"
)
self.es_c = 0 # restore (reset) the counter
# position of the smallest stored value, i.e. the slot to replace
ind = self.list_top_k_acc.index(acc_min)
# avoid having identical values in the list
if metric_val_current not in self.list_top_k_acc:
self.list_top_k_acc[ind] = metric_val_current
logger.info(
f"top k validation acc updated: \
{self.list_top_k_acc}"
)
# overwrite to ensure consistency
# issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwritten by the second epoch's validation acc.
# actually, after epoch 1, most often self._best_val_acc will be the higher value of self.list_top_k_acc and will be overwritten by min(self.list_top_k_acc)
# NOTE(review): the messages below hard-code "top-2" although top_k is configurable
logger.info(
f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to "
f"minimum of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} "
f"to ensure consistency"
)
self._best_val_acc = min(self.list_top_k_acc)
# overwrite test acc, this does not depend on whether the val top-k acc has been overwritten or not
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self._sel_model_te_acc != metric_te_current:
# this can only happen if the validation acc has decreased and the current val acc is only bigger than min(self.list_top_k_acc) but lower than max(self.list_top_k_acc)
logger.info(
f"top-2 val sel: overwriting selected model test acc from "
f"{self._sel_model_te_acc} to {metric_te_current} to ensure consistency"
)
self._sel_model_te_acc = metric_te_current
return True # belongs to: if metric_val_current > acc_min
return flag_super # flag_super is the flag from super() = MSelValPerf
27 changes: 25 additions & 2 deletions domainlab/algos/observers/b_obvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,22 @@ def __init__(self, model_sel):
self.metric_val = None
self.perf_metric = None

self.flag_setpoint_changed_once = False

@property
def str_metric4msel(self):
"""
string representing the metric used for persisting models on the disk
"""
return self.host_trainer.str_metric4msel

def update(self, epoch):
def reset(self):
"""
reset the observer by resetting its model selector
(delegates to self.model_sel.reset(), returns nothing)
"""
self.model_sel.reset()

def update(self, epoch, flag_info=False):
logger = Logger.get_logger()
logger.info(f"epoch: {epoch}")
self.epo = epoch
Expand All @@ -53,12 +61,18 @@ def update(self, epoch):
self.loader_te, self.device
)
self.metric_te = metric_te
if self.model_sel.update():
if self.model_sel.update(flag_info):
logger.info("better model found")
self.host_trainer.model.save()
logger.info("persisted")
flag_stop = self.model_sel.if_stop()

flag_enough = epoch >= self.host_trainer.aconf.epos_min

self.flag_setpoint_changed_once |= flag_info
if self.host_trainer.aconf.force_setpoint_change_once:
return flag_stop & flag_enough & self.flag_setpoint_changed_once

return flag_stop & flag_enough

def accept(self, trainer):
Expand Down Expand Up @@ -104,6 +118,15 @@ def after_all(self):
metric_te.update({"acc_val": self.model_sel.best_val_acc})
else:
metric_te.update({"acc_val": -1})

if hasattr(self, "model_sel") and hasattr(
self.model_sel, "oracle_last_setpoint_sel_te_acc"
):
metric_te.update(
{"acc_setpoint": self.model_sel.oracle_last_setpoint_sel_te_acc}
)
else:
metric_te.update({"acc_setpoint": -1})
self.dump_prediction(model_ld, metric_te)
# save metric to one line in csv result file
self.host_trainer.model.visitor(metric_te)
Expand Down
18 changes: 15 additions & 3 deletions domainlab/algos/observers/c_obvisitor_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@ def __init__(self, observer):

def after_all(self):
self.observer.after_all()
self.observer.clean_up()
self.observer.clean_up() # FIXME should be self.clean_up???

def accept(self, trainer):
"""
forward the trainer to the accept() of the wrapped observer
"""
self.observer.accept(trainer)

def update(self, epoch):
return self.observer.update(epoch)
def update(self, epoch, flag_info=False):
return self.observer.update(epoch, flag_info)

def clean_up(self):
"""
delegate clean up to the wrapped observer
"""
self.observer.clean_up()

@property
def model_sel(self):
"""
model selector of the wrapped observer
"""
return self.observer.model_sel

@model_sel.setter
def model_sel(self, model_sel):
# the selector is stored on the wrapped observer, not on this wrapper
self.observer.model_sel = model_sel

@property
def metric_te(self):
"""
metric_te attribute of the wrapped observer
"""
return self.observer.metric_te

@property
def metric_val(self):
"""
metric_val attribute of the wrapped observer
"""
return self.observer.metric_val
Loading
Loading