Commit
Merge pull request #15 from boostcampaitech6/feat/cnn
feat: add CNN layers
Showing 2 changed files with 484 additions and 0 deletions.
@@ -0,0 +1,126 @@
from typing import Optional, Union, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, QuestionAnsweringModelOutput
from transformers.models.roberta.modeling_roberta import RobertaEmbeddings, RobertaEncoder, RobertaPooler, RobertaModel
class CNN_block(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CNN_block, self).__init__()
        # Treat the sequence positions as Conv1d channels and convolve along the hidden
        # dimension: conv1 doubles the channel count, conv2 projects it back.
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=input_size * 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=input_size * 2, out_channels=input_size, kernel_size=1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # x: (batch_size, input_size, hidden_size)
        output = self.conv1(x)
        output = self.conv2(output)
        # Residual connection, then layer normalization over the hidden dimension.
        output = x + self.relu(output)
        output = self.layer_norm(output)

        return output

class CNN_RobertaForQuestionAnswering(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Five residual CNN blocks stacked on top of the encoder output.
        # config.max_seq_len is not a standard RobertaConfig field; it must be set on the
        # config before the model is built and must match the padded input length.
        self.cnn_block1 = CNN_block(config.max_seq_len, config.hidden_size)
        self.cnn_block2 = CNN_block(config.max_seq_len, config.hidden_size)
        self.cnn_block3 = CNN_block(config.max_seq_len, config.hidden_size)
        self.cnn_block4 = CNN_block(config.max_seq_len, config.hidden_size)
        self.cnn_block5 = CNN_block(config.max_seq_len, config.hidden_size)

        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Pass the encoder output through the five CNN blocks.
        sequence_output = self.cnn_block1(sequence_output)
        sequence_output = self.cnn_block2(sequence_output)
        sequence_output = self.cnn_block3(sequence_output)
        sequence_output = self.cnn_block4(sequence_output)
        sequence_output = self.cnn_block5(sequence_output)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
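A quick way to see what CNN_block does to its input: because the (fixed) sequence length is used as the Conv1d channel count and the convolution runs along the hidden dimension, the residual output keeps the encoder's shape. The sketch below is not part of the commit; the sequence length of 384 and hidden size of 768 are assumed stand-ins for config.max_seq_len and config.hidden_size.

import torch

# Hypothetical sizes: 384 = assumed config.max_seq_len, 768 = RoBERTa-base hidden size.
block = CNN_block(input_size=384, hidden_size=768)
x = torch.randn(2, 384, 768)   # (batch_size, seq_len, hidden_size), as produced by the encoder
out = block(x)
print(out.shape)               # torch.Size([2, 384, 768]), same shape as the input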
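Because each CNN block's in_channels is fixed to config.max_seq_len, every batch must be padded (or truncated) to exactly that length before the forward pass. The usage sketch below is not part of the commit: the "roberta-base" checkpoint, the max_seq_len value, and the example question/context are assumptions for illustration, and the newly added CNN blocks and QA head remain randomly initialized until fine-tuned.

import torch
from transformers import AutoConfig, AutoTokenizer

MAX_SEQ_LEN = 384  # assumed value; must equal config.max_seq_len

config = AutoConfig.from_pretrained("roberta-base")        # hypothetical checkpoint
config.max_seq_len = MAX_SEQ_LEN                           # required by CNN_RobertaForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = CNN_RobertaForQuestionAnswering.from_pretrained("roberta-base", config=config)
model.eval()

# Inputs are padded to exactly MAX_SEQ_LEN, since the Conv1d channel count is fixed.
inputs = tokenizer(
    "Who wrote Hamlet?",                                     # question
    "Hamlet is a tragedy written by William Shakespeare.",   # context
    padding="max_length",
    max_length=MAX_SEQ_LEN,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)  # not meaningful until the CNN blocks and QA head are fine-tuned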