initial version for roberta squad

This commit is contained in:
erenup
2019-12-13 14:51:40 +08:00
parent 40ed717232
commit 9b312f9d41
5 changed files with 106 additions and 19 deletions

View File

@@ -39,6 +39,7 @@ from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer,
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering,
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
'start_positions': batch[3],
'end_positions': batch[4]
'end_positions': batch[4],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad():
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1]
'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)