-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate.py
44 lines (39 loc) · 1.25 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path

from ztfparsnip import io
from ztfparsnip.create import CreateLightcurves
from ztfparsnip.train import Train
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Class labels (under the "simpleclasses" key) that receive a weight.
CLASSES = ("sn_ia", "tde", "sn_other", "agn", "star")

# Single per-class weight so every class stays balanced when the value is
# changed in one place. An earlier configuration used 9400 per class.
# NOTE(review): presumably the target number of light curves per class in
# CreateLightcurves — confirm against ztfparsnip.create.
WEIGHT_PER_CLASS = 15650

# Per-class weights handed to CreateLightcurves(weights=...).
weights = dict.fromkeys(CLASSES, WEIGHT_PER_CLASS)
def parse_args(argv=None):
    """Parse the command line to decide which pipeline stages to run.

    With no arguments only the classification stage is enabled, which is
    what this script ran before the stages became command-line options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Boolean flags ``create``, ``train``, ``classify``, ``evaluate`` and
        ``model_path`` (a ``Path`` or ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Create light curves, train, classify and evaluate with ztfparsnip."
    )
    parser.add_argument(
        "--create",
        action="store_true",
        help="generate light curves (CreateLightcurves.select + create)",
    )
    parser.add_argument(
        "--train", action="store_true", help="train a model (Train.run)"
    )
    parser.add_argument(
        "--skip-classify",
        dest="classify",
        action="store_false",
        help="do not run Train.classify (it runs by default)",
    )
    parser.add_argument(
        "--evaluate", action="store_true", help="run Train.evaluate"
    )
    parser.add_argument(
        "--model-path",
        type=Path,
        default=None,
        help=(
            "model file passed to Train.classify, e.g. "
            "data/models/train_bts_all_model_with_z.hd5"
        ),
    )
    return parser.parse_args(argv)


def create_lightcurves():
    """Select the sample and write ParSNIP-format light curves.

    NOTE(review): these settings were disabled in the script for a while;
    confirm they still match the current CreateLightcurves signature.
    """
    sample = CreateLightcurves(
        output_format="parsnip",
        classkey="simpleclasses",
        weights=weights,
        train_dir=Path("train_parsnip_fixstar"),
        plot_dir=Path("plot"),
        test_dir=Path("test_parsnip_fixstar"),
        seed=0,
        phase_lim=False,
        k_corr=True,
        test_fraction=0.0,
    )
    sample.select()
    sample.create(
        plot_debug=False, subsampling_rate=0.9, jd_scatter_sigma=0.03, start=0
    )


def main(argv=None):
    """Run the selected pipeline stages in order: create, train, classify, evaluate.

    ``Train`` is only constructed when a stage that uses it was requested.
    """
    args = parse_args(argv)

    if args.create:
        logger.info("Creating light curves")
        create_lightcurves()

    if not (args.train or args.classify or args.evaluate):
        return

    train = Train(
        data_dir=Path("data"),
        classkey="simpleclasses",
        train_validation_fraction=0.7,
        no_redshift=False,
        seed=0,
    )
    if args.train:
        logger.info("Training model")
        train.run()
    if args.classify:
        logger.info("Classifying")
        # Only forward model_path when given so Train.classify keeps its own default.
        if args.model_path is None:
            train.classify()
        else:
            train.classify(model_path=args.model_path)
    if args.evaluate:
        logger.info("Evaluating")
        train.evaluate()


if __name__ == "__main__":
    main()