Commit 4b7e20b

Merge pull request #66 from epfLLM/tokens_per_second

Tokens per second metric

martinjaggi authored Sep 25, 2023
2 parents: d7e3d04 + 5b4ae47

Showing 5 changed files with 42 additions and 5 deletions.
finetune.py: 18 changes (16 additions, 2 deletions)
@@ -5,9 +5,10 @@

 import torch

-from megatron import get_args, get_tokenizer, get_timers, print_rank_0
+from megatron import get_args, get_tokenizer, get_timers, get_counters, print_rank_0
 from megatron.training import pretrain
 from megatron.core import tensor_parallel
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel
 from megatron.utils import get_ltor_masks_and_position_ids, average_losses_across_data_parallel_group
 from megatron.data.gpt_dataset import build_train_valid_test_datasets as gpt_build_datasets
@@ -119,8 +120,21 @@ def get_batch(data_iterator):
     tokens = data_b["text"]
     labels = tokens[:, 1:].contiguous()
     tokens = tokens[:, :-1].contiguous()
-    if args.data_type == "gpt":
+
+    # Update tokens counter.
+    counters = get_counters()
+    n_tokens = torch.tensor(tokens.numel(), device=tokens.device)
+    if args.data_parallel_size == 1:
+        n_tokens = n_tokens.item()
+    else:
+        group = get_data_parallel_group()
+        torch.distributed.all_reduce(
+            n_tokens, op=torch.distributed.ReduceOp.SUM, group=group
+        )
+        n_tokens = n_tokens.item()
+    counters["tokens"] += n_tokens
+
+    if args.data_type == "gpt":
         # Get the masks and position ids.
         attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
             tokens,
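The hunk above is the heart of the feature: every call to get_batch counts the tokens in its micro-batch and, when data parallelism is active, all-reduces that count so every rank records the same global total. The sketch below restates the pattern outside Megatron; the function name and arguments are illustrative, and it assumes torch.distributed is already initialized (e.g. via torchrun).

import torch
import torch.distributed as dist

def count_batch_tokens(tokens: torch.Tensor, data_parallel_size: int, group=None) -> int:
    # Per-rank token count for this micro-batch.
    n_tokens = torch.tensor(tokens.numel(), device=tokens.device)
    if data_parallel_size == 1:
        return int(n_tokens.item())
    # Sum the counts across data-parallel ranks; every rank receives the global total.
    dist.all_reduce(n_tokens, op=dist.ReduceOp.SUM, group=group)
    return int(n_tokens.item())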
megatron/__init__.py: 1 change (1 addition, 0 deletions)
@@ -10,6 +10,7 @@
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .global_vars import get_counters

 from .utils import (print_rank_0,
                     print_all_nodes,
megatron/global_vars.py: 15 changes (15 additions, 0 deletions)
@@ -4,6 +4,7 @@

 import os
 import sys
+from collections import defaultdict

 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
@@ -17,6 +18,7 @@
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
+_GLOBAL_COUNTERS = None


 def get_args():
@@ -62,6 +64,12 @@ def get_timers():
     return _GLOBAL_TIMERS


+def get_counters():
+    """Return counters."""
+    _ensure_var_is_initialized(_GLOBAL_COUNTERS, 'counters')
+    return _GLOBAL_COUNTERS
+
+
 def get_signal_handler():
     _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     return _GLOBAL_SIGNAL_HANDLER
@@ -90,6 +98,7 @@ def set_global_variables(args):
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers(args)
+    _set_counters(args)

     if args.exit_signal_handler:
         _set_signal_handler()
@@ -178,6 +187,12 @@ def _set_timers(args):
     _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


+def _set_counters(args):
+    global _GLOBAL_COUNTERS
+    _ensure_var_is_not_initialized(_GLOBAL_COUNTERS, 'counters')
+    _GLOBAL_COUNTERS = defaultdict(int)
+
+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
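The new counters object is a process-wide defaultdict(int) exposed through the same get_/_set_ pattern as the other globals in this file, so any module can increment a named counter and the training loop can read or reset it later. A toy standalone version of that registry (hypothetical names, not the Megatron module itself):

from collections import defaultdict

_COUNTERS = defaultdict(int)   # missing keys start at 0

def get_counters():
    return _COUNTERS

# A producer (e.g. the batch loader) accumulates:
get_counters()["tokens"] += 4096
get_counters()["tokens"] += 4096

# A consumer (e.g. the logger) drains the counter; the next read starts from 0 again:
print(get_counters().pop("tokens"))   # 8192
print(get_counters()["tokens"])       # 0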
megatron/timers.py: 3 changes (0 additions, 3 deletions)
@@ -302,6 +302,3 @@ def write(self, names, writer, iteration, normalizer=1.0,
         for name in name_to_min_max_time:
             _, max_time = name_to_min_max_time[name]
             writer.add_scalar(name + '-time', max_time, iteration)
-        # if using wandb writer, flush the stats we just filled here, close to the creation time
-        if hasattr(writer,"flush_all"):
-            writer.flush_all()
megatron/training.py: 10 changes (10 additions, 0 deletions)
@@ -17,6 +17,7 @@
 from megatron import get_args
 from megatron import get_signal_handler
 from megatron import get_timers
+from megatron import get_counters
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
@@ -590,17 +591,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
+        counters = get_counters()
+        tokens = counters.pop('tokens') # reset counter for future iterations
+        tokens_per_sec = tokens/(elapsed_time)
         if writer:
             if args.log_timers_to_tensorboard:
                 writer.add_scalar('iteration-time',
                                   elapsed_time_per_iteration, iteration)
+            writer.add_scalar('tokens-per-sec', tokens_per_sec, iteration)

         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time_per_iteration * 1000.0)
+        log_string += f' rate (tokens/sec): {tokens_per_sec:.2f} |'
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
         log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:
@@ -668,6 +674,7 @@ def _train(args, forward_step_func,
     # Iterations.
     iteration = args.iteration

+    counters = get_counters()
     timers('interval-time', log_level=0).start(barrier=True)
     print_datetime('before the start of training step')
     report_memory_flag = True
@@ -706,10 +713,13 @@
         if args.eval_interval and iteration % args.eval_interval == 0 and \
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
+            current_tokens = counters['tokens']
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
                                        iteration, process_non_loss_data_func,
                                        verbose=False, args=args)
+            counters['tokens'] = current_tokens
+

         # if using wandb writer, flush the stats of train_step & potentially evaluate
         writer = get_tensorboard_writer()
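In training_log the counter is drained once per log interval and divided by the interval's wall-clock time, which is what lands in the 'rate (tokens/sec)' log field and the tokens-per-sec TensorBoard scalar; in _train the counter is snapshotted before evaluation and restored afterwards, so validation batches do not inflate the training throughput. A minimal sketch of the logging arithmetic, with made-up numbers standing in for real batches:

import time
from collections import defaultdict

counters = defaultdict(int)
log_interval = 50

interval_start = time.time()
for step in range(1, 101):
    time.sleep(0.001)               # stand-in for one training step
    counters["tokens"] += 4 * 2048  # e.g. 4 sequences of 2048 tokens per step
    if step % log_interval == 0:
        elapsed = time.time() - interval_start
        tokens = counters.pop("tokens")  # reset for the next interval
        print(f"iteration {step:8d} | rate (tokens/sec): {tokens / elapsed:.2f}")
        interval_start = time.time()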
