Skip to content

Commit

Permalink
Compatible upgrade of sparse_momentum for master param (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Mar 1, 2022
1 parent ef0a32b commit 731eb3f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions dynamic/utils/hybrid_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,20 @@ def _append_optimize_op(self, block, param_and_grad):
if getattr(param_and_grad[0], 'is_sparse_grad', None):
index = getattr(param_and_grad[0], 'index', None)
axis = getattr(param_and_grad[0], 'axis', None)
_, _ = paddle._C_ops.sparse_momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
param_and_grad[0], velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov, 'regularization_method',
self._regularization_method, 'regularization_coeff',
self._regularization_coeff, 'axis', axis)
try:
_, _ = paddle._C_ops.sparse_momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
param_and_grad[0], velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov, 'regularization_method',
self._regularization_method, 'regularization_coeff',
self._regularization_coeff, 'axis', axis)
except:
_, _, _ = paddle._C_ops.sparse_momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr, master_weight,
param_and_grad[0], velocity_acc, master_weight, 'mu', self._momentum,
'use_nesterov', self._use_nesterov, 'regularization_method',
self._regularization_method, 'regularization_coeff',
self._regularization_coeff, 'axis', axis, 'multi_precision', find_master)
else:
_, _, _ = paddle._C_ops.momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
Expand Down

0 comments on commit 731eb3f

Please sign in to comment.