initial version for roberta squad
This commit is contained in:
@@ -39,6 +39,7 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||||
BertForQuestionAnswering, BertTokenizer,
|
BertForQuestionAnswering, BertTokenizer,
|
||||||
|
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
|
||||||
XLMConfig, XLMForQuestionAnswering,
|
XLMConfig, XLMForQuestionAnswering,
|
||||||
XLMTokenizer, XLNetConfig,
|
XLMTokenizer, XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||||
|
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||||
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
inputs = {
|
inputs = {
|
||||||
'input_ids': batch[0],
|
'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
|
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||||
'start_positions': batch[3],
|
'start_positions': batch[3],
|
||||||
'end_positions': batch[4]
|
'end_positions': batch[4],
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.model_type != 'distilbert':
|
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
|
||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
|
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
|
||||||
|
|
||||||
@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {
|
inputs = {
|
||||||
'input_ids': batch[0],
|
'input_ids': batch[0],
|
||||||
'attention_mask': batch[1]
|
'attention_mask': batch[1],
|
||||||
|
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.model_type != 'distilbert':
|
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
|
||||||
|
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
|
|
||||||
# XLNet and XLM use more arguments for their predictions
|
# XLNet and XLM use more arguments for their predictions
|
||||||
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
|
||||||
|
|
||||||
# Compute the F1 and exact scores.
|
# Compute the F1 and exact scores.
|
||||||
results = squad_evaluate(examples, predictions)
|
results = squad_evaluate(examples, predictions)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ if is_torch_available():
|
|||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
|
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
|
||||||
RobertaForSequenceClassification, RobertaForMultipleChoice,
|
RobertaForSequenceClassification, RobertaForMultipleChoice,
|
||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification, RobertaForQuestionAnswering,
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
|
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
|
||||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||||
|
|||||||
@@ -377,7 +377,8 @@ def compute_predictions_logits(
|
|||||||
output_null_log_odds_file,
|
output_null_log_odds_file,
|
||||||
verbose_logging,
|
verbose_logging,
|
||||||
version_2_with_negative,
|
version_2_with_negative,
|
||||||
null_score_diff_threshold
|
null_score_diff_threshold,
|
||||||
|
tokenizer,
|
||||||
):
|
):
|
||||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||||
@@ -474,11 +475,14 @@ def compute_predictions_logits(
|
|||||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||||
tok_text = " ".join(tok_tokens)
|
|
||||||
|
|
||||||
# De-tokenize WordPieces that have been split off.
|
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||||
tok_text = tok_text.replace(" ##", "")
|
|
||||||
tok_text = tok_text.replace("##", "")
|
# tok_text = " ".join(tok_tokens)
|
||||||
|
#
|
||||||
|
# # De-tokenize WordPieces that have been split off.
|
||||||
|
# tok_text = tok_text.replace(" ##", "")
|
||||||
|
# tok_text = tok_text.replace("##", "")
|
||||||
|
|
||||||
# Clean whitespace
|
# Clean whitespace
|
||||||
tok_text = tok_text.strip()
|
tok_text = tok_text.strip()
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
tok_to_orig_index.append(i)
|
tok_to_orig_index.append(i)
|
||||||
all_doc_tokens.append(sub_token)
|
all_doc_tokens.append(sub_token)
|
||||||
|
|
||||||
|
|
||||||
if is_training and not example.is_impossible:
|
if is_training and not example.is_impossible:
|
||||||
tok_start_position = orig_to_tok_index[example.start_position]
|
tok_start_position = orig_to_tok_index[example.start_position]
|
||||||
if example.end_position < len(example.doc_tokens) - 1:
|
if example.end_position < len(example.doc_tokens) - 1:
|
||||||
@@ -155,7 +154,8 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
spans = []
|
spans = []
|
||||||
|
|
||||||
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
|
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
|
||||||
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
|
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
|
||||||
|
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
|
||||||
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
|
||||||
|
|
||||||
span_doc_tokens = all_doc_tokens
|
span_doc_tokens = all_doc_tokens
|
||||||
|
|||||||
@@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module):
|
|||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings("""Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
                      ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the average of the Cross-Entropy losses for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config)
        # One linear projection yields both heads; forward() unpacks it into
        # exactly two tensors, so config.num_labels is expected to be 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                start_positions=None, end_positions=None):
        """Compute span start/end logits and, when both labels are given, the span loss.

        Returns a tuple ``((loss), start_logits, end_logits, ...)`` where the
        trailing elements are whatever the encoder returned from index 2 on.
        """
        outputs = self.roberta(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # outputs[1] is dropped here (presumably the pooled output -- confirm
        # against RobertaModel); anything from index 2 on is passed through.
        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            # Out-of-place clamp so the caller's label tensors are not modified.
            # Out-of-range positions collapse onto `ignored_index`, which the
            # loss below is configured to skip.
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||||
Reference in New Issue
Block a user