bart.py
import torch.nn.functional as F
from transformers import BartForConditionalGeneration
# NOTE: transformers.modeling_bart only exists in older releases (pre-4.0);
# transformers 4.x moved this module to transformers.models.bart.modeling_bart.
from transformers.modeling_bart import shift_tokens_right
from utils import label_smoothed_nll_loss  # local helper: fairseq-style label-smoothed NLL


class MyBart(BartForConditionalGeneration):
    """BART with a label-smoothed cross-entropy loss for seq2seq training."""

    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
                decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
                use_cache=False, is_training=False):
        if is_training:
            # Teacher forcing: shift the targets one position to the right so
            # the decoder predicts token t from tokens < t.
            _decoder_input_ids = shift_tokens_right(decoder_input_ids, self.config.pad_token_id)
        else:
            _decoder_input_ids = decoder_input_ids
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        # Project decoder hidden states onto the vocabulary with the shared
        # (tied) embedding matrix, as BartForConditionalGeneration does.
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        if is_training:
            lprobs = F.log_softmax(lm_logits, dim=-1)
            # Label-smoothed NLL against the unshifted targets; pad tokens are ignored.
            loss, _ = label_smoothed_nll_loss(lprobs, decoder_input_ids,
                                              epsilon=0.1, ignore_index=self.config.pad_token_id)
            return loss
        return (lm_logits,) + outputs[1:]
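

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical example
# of training vs. inference with MyBart. It assumes an older transformers
# release (roughly 2.11-3.0, matching the imports above) where BartTokenizer
# and the "facebook/bart-large" checkpoint are available; the example batch
# and generation settings are illustrative only.
if __name__ == "__main__":
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = MyBart.from_pretrained("facebook/bart-large")

    src = tokenizer.batch_encode_plus(["What is the capital of France?"],
                                      return_tensors="pt")
    tgt = tokenizer.batch_encode_plus(["Paris"], return_tensors="pt")

    # Training: forward() returns the label-smoothed loss directly.
    model.train()
    loss = model(input_ids=src["input_ids"], attention_mask=src["attention_mask"],
                 decoder_input_ids=tgt["input_ids"],
                 decoder_attention_mask=tgt["attention_mask"],
                 is_training=True)
    loss.backward()

    # Inference: forward() returns (lm_logits,) + extras like the stock model,
    # so the parent class's generate() works unchanged.
    model.eval()
    generated = model.generate(src["input_ids"], num_beams=4, max_length=20)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))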