| | """ |
| | Original work: |
| | https://github.com/sangHa0411/CloneDetection/blob/main/models/codebert.py#L169 |
| | |
| | Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) |
| | |
| | All credits to the original authors. |
| | """ |
| | import torch.nn as nn |
| | from transformers import ( |
| | RobertaPreTrainedModel, |
| | RobertaModel, |
| | ) |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| |
|
| |
|
class CloneDetectionModel(RobertaPreTrainedModel):
    """RoBERTa encoder with a classifier over CLS/SEP token states.

    The hidden states found at every CLS and SEP position are gathered,
    passed through a small dropout/linear/ReLU network, flattened per sample
    and fed to a linear classifier. The classifier input width is
    ``4 * hidden_size``, so each sample must contain exactly four CLS/SEP
    tokens in total.

    Besides the usual RoBERTa fields, the config must provide
    ``tokenizer_cls_token_id`` and ``tokenizer_sep_token_id``; ``forward``
    reads both to locate the special tokens.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder, the per-token projection and the classifier.

        Args:
            config: model configuration; ``hidden_size``,
                ``hidden_dropout_prob`` and ``num_labels`` are read here.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # No pooling layer: forward() pools the special-token states itself.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Applied independently to each gathered special-token state.
        self.net = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
        )
        # Four special-token states are concatenated per sample.
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify each sample from its CLS/SEP hidden states.

        Args:
            input_ids: token ids, indexed as (batch, sequence). Required,
                because the special-token positions are located by comparing
                these ids with the configured CLS/SEP ids.
            attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to the RoBERTa encoder.
            labels: optional class indices; when given, a cross-entropy loss
                is computed against the logits.
            return_dict: return a ``SequenceClassifierOutput`` instead of a
                tuple; defaults to ``config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutput`` with ``loss``, ``logits``,
            ``hidden_states`` and ``attentions``; or, when ``return_dict`` is
            false, the tuple ``(loss?, logits, *encoder_extras)``.

        Raises:
            ValueError: if ``input_ids`` is ``None``, or if any sample does
                not contain the number of CLS/SEP tokens the classifier
                expects.
        """
        if input_ids is None:
            # Without ids the comparisons below degrade to plain Python
            # booleans and would index hidden_states incorrectly.
            raise ValueError(
                "input_ids is required: CLS/SEP positions are located by "
                "token id and cannot be derived from inputs_embeds."
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        batch_size, _, hidden_size = hidden_states.shape

        # Logical OR of the CLS and SEP position masks.
        cls_flag = input_ids == self.config.tokenizer_cls_token_id
        sep_flag = input_ids == self.config.tokenizer_sep_token_id
        special_mask = cls_flag | sep_flag

        # The reshape below regroups the gathered states by sample, which is
        # only valid when every sample contributes the same number of tokens;
        # otherwise states would leak across samples or fail with an opaque
        # shape error.
        tokens_per_sample = self.classifier.in_features // hidden_size
        if not bool((special_mask.sum(dim=1) == tokens_per_sample).all()):
            raise ValueError(
                f"Each sample must contain exactly {tokens_per_sample} CLS/SEP "
                "tokens; check tokenization and truncation settings."
            )

        special_token_states = hidden_states[special_mask].view(
            batch_size, -1, hidden_size
        )
        special_hidden_states = self.net(special_token_states)

        # Concatenate the projected special-token states of each sample.
        pooled_output = special_hidden_states.view(batch_size, -1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|