Skip to content

Commit

Permalink
[CheckpointSaver] Add saving-listener support for the incremental checkpoint saver. (#915)
Browse files Browse the repository at this point in the history

Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
  • Loading branch information
JackMoriarty authored Jul 26, 2023
1 parent e2037de commit 4af2db0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
16 changes: 13 additions & 3 deletions tensorflow/python/training/basic_session_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,7 @@ def after_run(self, run_context, run_values):
global_step = run_context.session.run(self._global_step_tensor)
if self._incremental_timer.should_trigger_for_step(global_step):
self._incremental_timer.update_last_triggered_step(global_step)
logging.info("Start Save incremental checkpoints for %d into %s.", global_step, self._incremental_save_path)
self._get_incr_saver().incremental_save(run_context.session, self._incremental_save_path, global_step=global_step)
logging.info("Finish Save incremental checkpoints for %d into %s.", global_step, self._incremental_save_path)
self._incr_save(run_context.session, global_step)


def end(self, session):
Expand Down Expand Up @@ -666,6 +664,18 @@ def _get_saver(self):
self._saver = savers[0]
return savers[0]

def _incr_save(self, session, step):
    """Perform one incremental save, wrapped by listener callbacks.

    Logs the save, invokes every registered listener's `before_save`,
    delegates the actual write to the incremental saver, then invokes
    every listener's `after_save`.

    Args:
      session: The `tf.Session` used to run the save ops.
      step: Global step value the checkpoint is associated with.
    """
    logging.info("Saving incremental checkpoints for %d into %s.", step,
                 self._incremental_save_path)

    # Give listeners a chance to act just before the checkpoint is written.
    for listener in self._listeners:
        listener.before_save(session, step)

    self._get_incr_saver().incremental_save(
        session, self._incremental_save_path, global_step=step)

    # Notify listeners once the checkpoint write has completed.
    for listener in self._listeners:
        listener.after_save(session, step)

def _get_incr_saver(self):
if self._scaffold is not None:
return self._scaffold._incr_saver
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/python/training/monitored_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,8 @@ def MonitoredTrainingSession(
save_checkpoint_steps=USE_DEFAULT,
summary_dir=None,
save_incremental_checkpoint_secs=None,
target_nodes_or_tensors=None):
target_nodes_or_tensors=None,
saving_listeners=None):

"""Creates a `MonitoredSession` for training.
Expand Down Expand Up @@ -548,6 +549,9 @@ def MonitoredTrainingSession(
summaries. If None, checkpoint_dir is used instead.
target_nodes_or_tensors: list of tf.Tensor or tf.Operation indicating the
targets, which determine the 'smart-stage' graph transformation.
saving_listeners: List of `CheckpointSaverListener` subclass instances. Used
for callbacks that run immediately before or after this hook saves the
checkpoint.
Returns:
A `MonitoredSession` object.
Expand Down Expand Up @@ -648,6 +652,7 @@ def MonitoredTrainingSession(
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold,
listeners=saving_listeners,
incremental_save_secs=save_incremental_checkpoint_secs))

if hooks:
Expand Down

0 comments on commit 4af2db0

Please sign in to comment.