loss mask before split
RangiLyu committed Jul 11, 2024
1 parent ce9bb62 commit 8753d9a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions xtuner/model/dpo.py
@@ -147,6 +147,10 @@ def compute_loss(self, data, data_samples=None):
         data['labels'] = torch.cat(
             (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
             dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        all_loss_mask = data[
+            'labels'] != -100  # loss mask of all tokens in all sp ranks # noqa

         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)
@@ -161,7 +165,7 @@ def compute_loss(self, data, data_samples=None):

         labels = data['labels']
         labels[labels == -100] = 0
-        loss_mask = labels != 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
         policy_logps = self._gather_masked_logits(all_logits, labels,
                                                   loss_mask)
         ref_logps = self._gather_masked_logits(all_ref_logits, labels,
@@ -183,15 +187,15 @@ def compute_loss(self, data, data_samples=None):
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_logps(
-                 policy_logps, ref_logps, loss_mask)
+                 policy_logps, ref_logps, all_loss_mask)
         else:
             message_hub = MessageHub.get_instance('varlen_attn_args')
             rank = dist.get_rank()
             cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_var_len_atten_logps(
-                 policy_logps, ref_logps, loss_mask, cu_seqlens,
+                 policy_logps, ref_logps, all_loss_mask, cu_seqlens,
                  data['attention_mask'])

         pi_logratios = policy_chosen_logps - policy_rejected_logps
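
What the change does: `all_loss_mask` is now computed from the full label tensor before `_split_for_sequence_parallel` slices `data` across sequence-parallel ranks, and it replaces the per-rank `loss_mask` in the calls to `get_logps` and `get_var_len_atten_logps`. The per-rank `loss_mask` is still used to gather logits locally. Below is a minimal sketch of why the order matters; the tensors and the world size are hypothetical, not xtuner code.

# A minimal sketch (hypothetical tensors, not xtuner code) of why the
# loss mask must be computed before the sequence-parallel split.
import torch

sp_world_size = 2  # assumed sequence-parallel world size

# Labels for one sample; -100 marks tokens excluded from the loss.
labels = torch.tensor([[-100, -100, 5, 7, 9, -100, 11, 13]])

# A mask taken BEFORE the split covers every token of the sequence.
all_loss_mask = labels != -100              # shape [1, 8], 5 True entries

# After the split, each rank sees only its own chunk, so a mask taken
# per rank can no longer describe the whole sequence.
chunks = labels.chunk(sp_world_size, dim=1)
rank0_loss_mask = chunks[0] != -100         # shape [1, 4], 2 True entries

print(all_loss_mask.sum().item())           # 5
print(rank0_loss_mask.sum().item())         # 2

If the mask were taken only after the split, the chosen/rejected log-probability aggregation would see just a rank-local view of each sequence, which is presumably why the full-sequence mask is captured up front here.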
