Initial version of RoBERTa support for SQuAD question answering
This commit is contained in:
@@ -39,6 +39,7 @@ from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]
|
||||
'end_positions': batch[4],
|
||||
}
|
||||
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
|
||||
|
||||
@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
with torch.no_grad():
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
|
||||
}
|
||||
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids
|
||||
|
||||
example_indices = batch[3]
|
||||
|
||||
# XLNet and XLM use more arguments for their predictions
|
||||
@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
|
||||
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
|
||||
Reference in New Issue
Block a user