-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
31 lines (24 loc) · 1.03 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import sys
from megatron_wrap.core import MegatronWrap
from megatron_wrap.utils import logger
from megatron_wrap.utils.dataset import EmptyDataset, ExampleSftDataset
# Train a model with MegatronWrap: pick the dataset matching the configured
# flow, run `train_iters` training iterations, and checkpoint periodically.
# Usage: python example.py <config_file>
if len(sys.argv) < 2:
    raise SystemExit("usage: example.py <config_file>")
config_file = sys.argv[1]
wrap = MegatronWrap(config_file)

# Select the dataset implementation for the configured flow key.
flow_key = wrap.get_flow_key()
if flow_key == "minimal_mock":
    dataset = EmptyDataset()
elif flow_key == "gpt_sft":
    dataset = ExampleSftDataset()
else:
    raise ValueError(
        f"unsupported flow key {flow_key!r}; expected 'minimal_mock' or 'gpt_sft'"
    )

# Fetch the common args object once instead of re-querying it per attribute.
common_args = wrap.get_common_args()
train_iters = common_args.train_iters
save_interval = common_args.save_interval
global_batch_size = common_args.global_batch_size
logger.info_rank_0(f"using config {config_file}, train for {train_iters} iters, save after each {save_interval} iters")

wrap.initialize()
wrap.setup_model_and_optimizer()
for current_iteration in range(train_iters):
    metrics = wrap.train(dataset.get_batch(global_batch_size))
    wrap.log_last_metrics()
    # Checkpoint on each save_interval boundary (skipping iteration 0),
    # and always once more on the final iteration.
    if save_interval is not None:
        on_interval = current_iteration % save_interval == 0 and current_iteration > 0
        is_last = current_iteration == train_iters - 1
        if on_interval or is_last:
            wrap.save()